diff --git a/.github/workflows/gpu_unit_tests.yml b/.github/workflows/gpu_unit_tests.yml index 54795d49313..df96a720988 100644 --- a/.github/workflows/gpu_unit_tests.yml +++ b/.github/workflows/gpu_unit_tests.yml @@ -113,7 +113,7 @@ jobs: pip3 install --ignore-installed mlflow "numpy<2.0" - name: Run all GPU unit tests run: | - pytest -s -x --ignore-glob="*test_special_*.py" --ignore-glob='*on_cpu.py' --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob='tests/special*' --ignore-glob="tests/experimental" --ignore-glob="tests/workers/reward_model" tests/ + pytest -s -x --ignore-glob="*on_npu.py" --ignore-glob="*test_special_*.py" --ignore-glob='*on_cpu.py' --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob='tests/special*' --ignore-glob="tests/experimental" --ignore-glob="tests/workers/reward_model" tests/ - name: Testing LinearCrossEntropyTP Correctness, Computation Time and Memory Consumption run: | LOW_MEMORY=True torchrun --standalone --nnodes=1 --nproc-per-node=8 tests/utils/test_special_linear_cross_entropy_tp.py diff --git a/.github/workflows/sgl.yml b/.github/workflows/sgl.yml index 2a269638cf2..b9d34a0ad28 100644 --- a/.github/workflows/sgl.yml +++ b/.github/workflows/sgl.yml @@ -113,6 +113,7 @@ jobs: fetch-depth: 0 - name: Install the current repository run: | + pip3 install cupy-cuda12x pytest-asyncio pip3 install hf_transfer fastmcp pytest-asyncio pip3 install -r requirements-test.txt pip3 install --no-deps -e . @@ -124,9 +125,36 @@ jobs: run: | ROLLOUT_NAME=sglang pytest -svvv tests/experimental/agent_loop + sgl_checkpoint_engine: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 35 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: 1 + SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK: "True" + NCCL_SHM_DISABLE: "1" + NCCL_P2P_DISABLE: "1" + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install cupy-cuda12x pytest-asyncio + pip3 install hf_transfer fastmcp pytest-asyncio + pip3 install -r requirements-test.txt + pip3 install --no-deps -e . 
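# NOTE (not part of the patch): cupy-cuda12x appears to be needed because the NCCL/NIXL
# checkpoint engines import cupy, and pytest-asyncio because the new checkpoint-engine
# tests are async coroutines.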
+ - name: Test SGLang ServerAdapter with Checkpoint Engine (NCCL) + run: | + ROLLOUT_NAME=sglang pytest -svvv tests/checkpoint_engine/test_special_server_adapter.py + cleanup: runs-on: ubuntu-latest - needs: [setup, sgl] + needs: [setup, sgl, sgl_checkpoint_engine] if: always() steps: - id: destroy-runner diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index 353e0aa2048..b9ca50693a2 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -126,11 +126,33 @@ jobs: - name: Test vllm server abort functionality run: | pytest tests/workers/rollout/rollout_vllm/test_vllm_abort.py -v -s - # Note(haibin.lin): for any new test, please update gpu_unit_tests.yaml to avoid repeated tests + + vllm_checkpoint_engine: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 35 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install cupy-cuda12x pytest-asyncio + pip3 install -r requirements-test.txt + pip3 install --no-deps -e . + - name: Test vLLM ServerAdapter with Checkpoint Engine (NCCL) + run: | + ROLLOUT_NAME=vllm pytest -svvv tests/checkpoint_engine/test_special_server_adapter.py cleanup: runs-on: ubuntu-latest - needs: [setup, vllm] + needs: [setup, vllm, vllm_checkpoint_engine] if: always() steps: - id: destroy-runner diff --git a/examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh b/examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh index ea41e169613..62bdb1cb2cc 100644 --- a/examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh +++ b/examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh @@ -72,7 +72,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.rollout.n=5 \ actor_rollout_ref.rollout.max_num_seqs=${MAX_BATCH_SIZE} \ actor_rollout_ref.rollout.max_num_batched_tokens=32768 \ - actor_rollout_ref.rollout.update_weights_bucket_megabytes=4096 \ + actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=4096 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${ACTOR_TP} \ +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_timeout_iters=32 \ diff --git a/examples/grpo_trainer/run_qwen2-7b_math_trtllm.sh b/examples/grpo_trainer/run_qwen2-7b_math_trtllm.sh index e61be265bfb..bcb95d9b822 100644 --- a/examples/grpo_trainer/run_qwen2-7b_math_trtllm.sh +++ b/examples/grpo_trainer/run_qwen2-7b_math_trtllm.sh @@ -75,7 +75,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.calculate_log_probs=True \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.rollout.update_weights_bucket_megabytes=4096 \ + actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=4096 \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ trainer.logger='["console","wandb"]' \ diff --git a/tests/checkpoint_engine/test_correctness_on_gpu.py b/tests/checkpoint_engine/test_correctness_on_gpu.py new file mode 100644 index 00000000000..ff4a959b20f --- /dev/null +++ 
b/tests/checkpoint_engine/test_correctness_on_gpu.py @@ -0,0 +1,139 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest +import ray + +from tests.checkpoint_engine.test_utils import create_rollout_worker_group, create_trainer_worker_group +from verl.checkpoint_engine import CheckpointEngineManager +from verl.single_controller.ray.base import ( + RayResourcePool, + split_resource_pool, +) +from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig + + +@pytest.mark.asyncio +@pytest.mark.parametrize("rebuild_group", [False, True]) +@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)]) +async def test_nccl_checkpoint_engine( + rebuild_group, + num_trainer, + num_rollout, + num_nodes=1, + num_gpus_per_node=8, + check_allclose=True, + model_path="~/models/Qwen/Qwen3-8B-Base", +): + model_path = os.path.expanduser(model_path) + ray.init( + runtime_env={ + "env_vars": { + "UCX_TLS": "rc,tcp,cuda", + "UCX_MAX_RNDV_RAILS": "4", + "UCX_LOG_LEVEL": "INFO", + "VERL_LOGGING_LEVEL": "DEBUG", + } + } + ) + + # initialize config + checkpoint_engine_config = CheckpointEngineConfig( + backend="nccl", engine_kwargs={"nccl": {"rebuild_group": rebuild_group}} + ) + model_config = HFModelConfig(path=model_path, use_remove_padding=True) + rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config) + + # create trainer and rollout worker group + resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3) + trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout]) + trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config) + trainer.reset() + rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose) + + # create checkpoint engine manager + checkpoint_manager = CheckpointEngineManager(backend="nccl", trainer=trainer, replicas=replicas) + for _ in range(3): + await checkpoint_manager.update_weights() + rollout.check_weights() + + ray.shutdown() + + +@pytest.mark.skip(reason="temporary skip since our ci environment is not ready") +@pytest.mark.asyncio +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)]) +async def test_nixl_checkpoint_engine( + num_trainer, + num_rollout, + device, + num_nodes=1, + num_gpus_per_node=8, + check_allclose=True, + model_path="~/models/Qwen/Qwen3-8B-Base", +): + model_path = os.path.expanduser(model_path) + ray.init( + runtime_env={ + "env_vars": { + # TODO: it's pretty hard to set these environment variables right, please consult + # with your network admin. Maybe auto adjust UCX_* according to NCCL_IB_*? 
+ "UCX_TLS": "rc,ud,cuda", + # "UCX_IB_GID_INDEX": "3", # NCCL_IB_GID_INDEX + # "UCX_IB_DEVICES": "mlx5_1:1,mlx5_2:1,mlx5_3:1", # NCCL_IB_HCA + "UCX_RC_TIMEOUT": "30s", # NCCL_IB_TIMEOUT + "UCX_RC_RETRY_COUNT": "7", # NCCL_IB_RETRY_COUNT + "UCX_KEEPALIVE_INTERVAL": "1s", + "UCX_KEEPALIVE_NUM_EPS": "10", + "UCX_MAX_RNDV_RAILS": "4", + "UCX_IB_ROCE_REACHABILITY_MODE": "all", + "UCX_LOG_LEVEL": "INFO", + "VERL_LOGGING_LEVEL": "DEBUG", + } + } + ) + + # initialize config + checkpoint_engine_config = CheckpointEngineConfig(backend="nixl", engine_kwargs={"nixl": {"device": device}}) + model_config = HFModelConfig(path=model_path, use_remove_padding=True) + rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config) + + # create trainer and rollout worker group + resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3) + trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout]) + trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config) + trainer.reset() + rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose) + + # create checkpoint engine manager + checkpoint_manager = CheckpointEngineManager(backend="nixl", trainer=trainer, replicas=replicas) + for _ in range(3): + await checkpoint_manager.update_weights() + rollout.check_weights() + + ray.shutdown() + + +if __name__ == "__main__": + test_nccl_checkpoint_engine( + rebuild_group=False, + num_trainer=2, + num_rollout=30, + num_nodes=4, + num_gpus_per_node=8, + check_allclose=False, + model_path=os.environ["HDFS_ROOT"] + "/model/Qwen3-30B-A3B-Base", + ) diff --git a/tests/checkpoint_engine/test_hccl_checkpoint_engine.py b/tests/checkpoint_engine/test_correctness_on_npu.py similarity index 51% rename from tests/checkpoint_engine/test_hccl_checkpoint_engine.py rename to tests/checkpoint_engine/test_correctness_on_npu.py index 79eeeea36a3..b99fcc771be 100644 --- a/tests/checkpoint_engine/test_hccl_checkpoint_engine.py +++ b/tests/checkpoint_engine/test_correctness_on_npu.py @@ -17,17 +17,19 @@ import ray from tests.checkpoint_engine.test_utils import create_rollout_worker_group, create_trainer_worker_group +from verl.checkpoint_engine import CheckpointEngineManager from verl.single_controller.ray.base import ( RayResourcePool, split_resource_pool, ) from verl.utils.device import get_device_name +from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig -@pytest.mark.skipif(get_device_name() != "npu", reason="NPU is not available") -@pytest.mark.parametrize("rebuild_group", [False, True]) +@pytest.mark.asyncio +@pytest.mark.parametrize("rebuild_group", [False]) @pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)]) -def test_hccl_checkpoint_engine( +async def test_hccl_checkpoint_engine( rebuild_group, num_trainer, num_rollout, @@ -48,55 +50,25 @@ def test_hccl_checkpoint_engine( } ) + # initialize config + checkpoint_engine_config = CheckpointEngineConfig( + backend="hccl", engine_kwargs={"hccl": {"rebuild_group": rebuild_group}} + ) + model_config = HFModelConfig(path=model_path, use_remove_padding=True) + rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config) + + # create trainer and rollout worker group resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3) resource_pool.get_placement_groups(device_name=get_device_name()) 
trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout]) - checkpoint_kwargs = { - "bucket_size": 2 * 1024 * 1024 * 1024, # 2GB - "rebuild_group": rebuild_group, - } - - trainer = create_trainer_worker_group(model_path, trainer_pool, "hccl", checkpoint_kwargs) + trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config) trainer.reset() - rollout = create_rollout_worker_group( - model_path, rollout_pool, "hccl", checkpoint_kwargs, check_allclose=check_allclose - ) + rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose) + # create checkpoint engine manager + checkpoint_manager = CheckpointEngineManager(backend="hccl", trainer=trainer, replicas=replicas) for _ in range(3): - # 1. prepare all workers - metadata = ray.get( - trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size) - + rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size) - ) - trainer_kwargs = { - "method": ["init_process_group"] * trainer.world_size, - "rank": [0] + [-1] * (trainer.world_size - 1), - "world_size": [rollout.world_size + 1] * trainer.world_size, - "master_metadata": [metadata[0]] * trainer.world_size, - } - rollout_kwargs = { - "method": ["init_process_group"] * rollout.world_size, - "rank": list(range(1, rollout.world_size + 1)), - "world_size": [rollout.world_size + 1] * rollout.world_size, - "master_metadata": [metadata[0]] * rollout.world_size, - } - - # 2. init process group between all workers - ray.get( - trainer.execute_checkpoint_engine(**trainer_kwargs) + rollout.execute_checkpoint_engine(**rollout_kwargs) - ) - - # 3. update weights of all workers - print("start to upate") - ray.get(trainer.update_weights() + rollout.update_weights()) - - # 4. finish all workers - ray.get( - trainer.execute_checkpoint_engine(["finish"] * trainer.world_size) - + rollout.execute_checkpoint_engine(["finish"] * rollout.world_size) - ) - print("end update") - # 5. check weights of rollout workers + await checkpoint_manager.update_weights() rollout.check_weights() ray.shutdown() diff --git a/tests/checkpoint_engine/test_nccl_checkpoint_engine.py b/tests/checkpoint_engine/test_nccl_checkpoint_engine.py deleted file mode 100644 index a04ceaaf25c..00000000000 --- a/tests/checkpoint_engine/test_nccl_checkpoint_engine.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
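# NOTE (illustrative sketch, not part of the patch): the example scripts above now set the weight
# bucket size under the nested checkpoint_engine config instead of the old flat rollout key. The
# equivalent config objects used by these tests look roughly like this (field names taken from
# this patch; defaults live in verl.workers.config):
from verl.workers.config import CheckpointEngineConfig, RolloutConfig

checkpoint_engine_config = CheckpointEngineConfig(
    backend="nccl",                                    # "hccl" on NPU, "nixl" for RDMA, "naive" for colocated sync
    update_weights_bucket_megabytes=4096,              # replaces rollout.update_weights_bucket_megabytes
    engine_kwargs={"nccl": {"rebuild_group": False}},  # backend-specific options
)
rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)

# workers convert the megabyte setting into the byte-sized transfer bucket:
bucket_size = checkpoint_engine_config.update_weights_bucket_megabytes << 20  # MiB -> bytes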
-import os - -import pytest -import ray - -from tests.checkpoint_engine.test_utils import create_rollout_worker_group, create_trainer_worker_group -from verl.single_controller.ray.base import ( - RayResourcePool, - split_resource_pool, -) -from verl.utils.device import get_device_name - - -@pytest.mark.skipif(get_device_name() != "cuda", reason="GPU is not available") -@pytest.mark.parametrize("rebuild_group", [False, True]) -@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)]) -def test_nccl_checkpoint_engine( - rebuild_group, - num_trainer, - num_rollout, - num_nodes=1, - num_gpus_per_node=8, - check_allclose=True, - model_path="~/models/Qwen/Qwen3-8B-Base", -): - model_path = os.path.expanduser(model_path) - ray.init( - runtime_env={ - "env_vars": { - "UCX_TLS": "rc,tcp,cuda", - "UCX_MAX_RNDV_RAILS": "4", - "UCX_LOG_LEVEL": "INFO", - "VERL_LOGGING_LEVEL": "DEBUG", - } - } - ) - - resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3) - trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout]) - checkpoint_kwargs = { - "bucket_size": 2 * 1024 * 1024 * 1024, # 2GB - "rebuild_group": rebuild_group, - } - - trainer = create_trainer_worker_group(model_path, trainer_pool, "nccl", checkpoint_kwargs) - trainer.reset() - rollout = create_rollout_worker_group( - model_path, rollout_pool, "nccl", checkpoint_kwargs, check_allclose=check_allclose - ) - - for _ in range(3): - # 1. prepare all workers - metadata = ray.get( - trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size) - + rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size) - ) - trainer_kwargs = { - "method": ["init_process_group"] * trainer.world_size, - "rank": [0] + [-1] * (trainer.world_size - 1), - "world_size": [rollout.world_size + 1] * trainer.world_size, - "master_metadata": [metadata[0]] * trainer.world_size, - } - rollout_kwargs = { - "method": ["init_process_group"] * rollout.world_size, - "rank": list(range(1, rollout.world_size + 1)), - "world_size": [rollout.world_size + 1] * rollout.world_size, - "master_metadata": [metadata[0]] * rollout.world_size, - } - - # 2. init process group between all workers - ray.get( - trainer.execute_checkpoint_engine(**trainer_kwargs) + rollout.execute_checkpoint_engine(**rollout_kwargs) - ) - - # 3. update weights of all workers - ray.get(trainer.update_weights() + rollout.update_weights()) - - # 4. finish all workers - ray.get( - trainer.execute_checkpoint_engine(["finish"] * trainer.world_size) - + rollout.execute_checkpoint_engine(["finish"] * rollout.world_size) - ) - - # 5. check weights of rollout workers - rollout.check_weights() - - ray.shutdown() - - -if __name__ == "__main__": - test_nccl_checkpoint_engine( - rebuild_group=False, - num_trainer=2, - num_rollout=30, - num_nodes=4, - num_gpus_per_node=8, - check_allclose=False, - model_path=os.environ["HDFS_ROOT"] + "/model/Qwen3-30B-A3B-Base", - ) diff --git a/tests/checkpoint_engine/test_nixl_checkpoint_engine.py b/tests/checkpoint_engine/test_nixl_checkpoint_engine.py deleted file mode 100644 index e8367e4ce4c..00000000000 --- a/tests/checkpoint_engine/test_nixl_checkpoint_engine.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -import pytest -import ray - -from tests.checkpoint_engine.test_utils import create_rollout_worker_group, create_trainer_worker_group -from verl.single_controller.ray.base import ( - RayResourcePool, - split_resource_pool, -) - - -@pytest.mark.skip(reason="temporary skip since our ci environment is not ready") -@pytest.mark.parametrize("device", ["cuda", "cpu"]) -@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)]) -def test_nixl_checkpoint_engine( - num_trainer, - num_rollout, - device, - num_nodes=1, - num_gpus_per_node=8, - check_allclose=True, - model_path="~/models/Qwen/Qwen3-8B-Base", -): - model_path = os.path.expanduser(model_path) - ray.init( - runtime_env={ - "env_vars": { - # TODO: it's pretty hard to set these environment variables right, please consult - # with your network admin. Maybe auto adjust UCX_* according to NCCL_IB_*? - "UCX_TLS": "rc,ud,cuda", - # "UCX_IB_GID_INDEX": "3", # NCCL_IB_GID_INDEX - # "UCX_IB_DEVICES": "mlx5_1:1,mlx5_2:1,mlx5_3:1", # NCCL_IB_HCA - "UCX_RC_TIMEOUT": "30s", # NCCL_IB_TIMEOUT - "UCX_RC_RETRY_COUNT": "7", # NCCL_IB_RETRY_COUNT - "UCX_KEEPALIVE_INTERVAL": "1s", - "UCX_KEEPALIVE_NUM_EPS": "10", - "UCX_MAX_RNDV_RAILS": "4", - "UCX_LOG_LEVEL": "INFO", - "VERL_LOGGING_LEVEL": "DEBUG", - } - } - ) - - resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3) - trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout]) - checkpoint_kwargs = { - "bucket_size": 2 * 1024 * 1024 * 1024, # 2GB - "device": device, - } - - trainer = create_trainer_worker_group(model_path, trainer_pool, "nixl", checkpoint_kwargs) - trainer.reset() - rollout = create_rollout_worker_group( - model_path, rollout_pool, "nixl", checkpoint_kwargs, device=device, check_allclose=check_allclose - ) - - for _ in range(3): - # 1. prepare all workers - metadata = ray.get( - trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size) - + rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size) - ) - - trainer_kwargs = { - "method": ["init_process_group"] * trainer.world_size, - "rank": [0] + [-1] * (trainer.world_size - 1), - "world_size": [rollout.world_size + 1] * trainer.world_size, - "prev_agent_metadata": [None] * trainer.world_size, - "next_agent_metadata": [metadata[-rollout.world_size]] + [None] * (trainer.world_size - 1), - } - - rollout_kwargs = { - "method": ["init_process_group"] * rollout.world_size, - "rank": list(range(1, rollout.world_size + 1)), - "world_size": [rollout.world_size + 1] * rollout.world_size, - "prev_agent_metadata": [metadata[0]] + metadata[-rollout.world_size : -1], - "next_agent_metadata": metadata[-rollout.world_size + 1 :] + [None], - } - - # 2. init process group between all workers - ray.get( - trainer.execute_checkpoint_engine(**trainer_kwargs) + rollout.execute_checkpoint_engine(**rollout_kwargs) - ) - - # 3. update weights of all workers - ray.get(trainer.update_weights() + rollout.update_weights()) - - # 4. 
finish all workers - ray.get( - trainer.execute_checkpoint_engine(["finish"] * trainer.world_size) - + rollout.execute_checkpoint_engine(["finish"] * rollout.world_size) - ) - - # 5. check weights of rollout workers - rollout.check_weights() - - ray.shutdown() - - -if __name__ == "__main__": - test_nixl_checkpoint_engine( - num_trainer=2, - num_rollout=30, - device="cuda", - num_nodes=4, - num_gpus_per_node=8, - check_allclose=False, - model_path=os.environ["HDFS_ROOT"] + "/model/Qwen3-30B-A3B-Base", - ) diff --git a/tests/checkpoint_engine/test_special_server_adapter.py b/tests/checkpoint_engine/test_special_server_adapter.py new file mode 100644 index 00000000000..bf640ce353e --- /dev/null +++ b/tests/checkpoint_engine/test_special_server_adapter.py @@ -0,0 +1,120 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import os + +import pytest +import ray +from omegaconf import DictConfig +from openai import AsyncOpenAI + +from tests.checkpoint_engine.test_utils import create_trainer_worker_group +from verl.checkpoint_engine import CheckpointEngineManager, CheckpointEngineWorker +from verl.single_controller.ray import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import get_device_name +from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig +from verl.workers.rollout.replica import get_rollout_replica_class + + +@pytest.fixture +def init_config() -> DictConfig: + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + + config.trainer.n_gpus_per_node = 8 + config.trainer.nnodes = 1 + config.actor_rollout_ref.model.path = os.path.expanduser("~/models/Qwen/Qwen3-VL-2B-Instruct") + config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"] + config.actor_rollout_ref.rollout.skip_tokenizer_init = False + config.actor_rollout_ref.rollout.checkpoint_engine.backend = "nccl" if get_device_name() == "cuda" else "hccl" + + return config + + +@pytest.mark.asyncio +async def test_server_adapter(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + "VLLM_DISABLE_COMPILE_CACHE": "1", + } + } + ) + + # 1. create trainer worker group + model_config: HFModelConfig = omega_conf_to_dataclass(init_config.actor_rollout_ref.model) + checkpoint_engine_config: CheckpointEngineConfig = omega_conf_to_dataclass( + init_config.actor_rollout_ref.rollout.checkpoint_engine + ) + trainer_pool = RayResourcePool(process_on_nodes=[4], max_colocate_count=3) + trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config) + trainer.reset() + + # 2. 
create rollout replicas + rollout_config: RolloutConfig = omega_conf_to_dataclass(init_config.actor_rollout_ref.rollout) + + # 2.1 create checkpoint engine worker group + rollout_pool = RayResourcePool(process_on_nodes=[4], max_colocate_count=3) + ray_cls_with_init = RayClassWithInitArgs( + cls=ray.remote(CheckpointEngineWorker), + model_config=model_config, + rollout_config=rollout_config, + ) + rollout = RayWorkerGroup( + resource_pool=rollout_pool, ray_cls_with_init=ray_cls_with_init, device_name=get_device_name() + ) + + # 2.2 create rollout replicas + rollout_replica_class = get_rollout_replica_class(rollout_config.name) + rollout_replicas = [ + rollout_replica_class( + replica_rank=replica_rank, + config=rollout_config, + model_config=model_config, + ) + for replica_rank in range(2) + ] + await asyncio.gather(*[replica.init_hybrid(rollout) for replica in rollout_replicas]) + + # 3. create checkpoint engine manager + checkpoint_manager = CheckpointEngineManager( + backend=checkpoint_engine_config.backend, trainer=trainer, replicas=rollout_replicas + ) + for i in range(3): + await checkpoint_manager.update_weights() + + server_addresses = rollout_replicas[i % len(rollout_replicas)].server_address + client = AsyncOpenAI( + api_key="123-abc", + base_url=f"http://{server_addresses}/v1", + ) + + completion = await client.chat.completions.create( + model=init_config.actor_rollout_ref.model.path, + messages=[{"role": "user", "content": "What can you do?"}], + ) + print("[OUTPUT]:", completion.choices[0].message.content) + + ray.shutdown() diff --git a/tests/checkpoint_engine/test_utils.py b/tests/checkpoint_engine/test_utils.py index 4e18b227d09..27cb055bd36 100644 --- a/tests/checkpoint_engine/test_utils.py +++ b/tests/checkpoint_engine/test_utils.py @@ -11,27 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
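# NOTE (illustrative sketch, not part of the patch): the pattern these tests converge on is to
# build a trainer worker group plus a set of rollout replicas and let CheckpointEngineManager
# drive the prepare -> build_topology -> init_process_group -> send/receive -> finalize cycle.
# Assuming `trainer` and `replicas` were created with the helpers below, one sync step is:
from verl.checkpoint_engine import CheckpointEngineManager


async def sync_once(trainer, replicas, backend: str = "nccl"):
    manager = CheckpointEngineManager(backend=backend, trainer=trainer, replicas=replicas)
    # aborts in-flight requests, forms a temporary worker group over all replica workers,
    # streams the weights from the trainer, then finalizes and resumes the replicas
    await manager.update_weights()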
+import asyncio +from typing import Generator import ray import torch from transformers import AutoModelForCausalLM -from verl.checkpoint_engine import CheckpointEngineRegistry +from verl.checkpoint_engine import CheckpointEngineRegistry, CheckpointEngineWorker from verl.single_controller.base.decorator import Dispatch, register from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.utils.device import get_device_name from verl.utils.fs import copy_to_local -from verl.workers.config import FSDPEngineConfig, HFModelConfig +from verl.workers.config import CheckpointEngineConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig from verl.workers.engine_workers import TrainingWorker, TrainingWorkerConfig +from verl.workers.rollout import BaseRollout, RolloutReplica class TrainingWorkerTest(TrainingWorker): - def __init__(self, config: TrainingWorkerConfig, checkpoint_backend: str, checkpoint_kwargs: dict) -> None: - copy_to_local(config.model_config.path) + def __init__(self, config: TrainingWorkerConfig, checkpoint_engine_config: CheckpointEngineConfig) -> None: super().__init__(config) - if torch.distributed.get_rank() == 0 and checkpoint_backend in ["nccl", "hccl"]: - checkpoint_kwargs["is_master"] = True - self.checkpoint_engine = CheckpointEngineRegistry.new(checkpoint_backend, **checkpoint_kwargs) + backend = checkpoint_engine_config.backend + bucket_size = checkpoint_engine_config.update_weights_bucket_megabytes << 20 + engine_kwargs = checkpoint_engine_config.engine_kwargs.get(backend, {}) + if torch.distributed.get_rank() == 0 and backend in ["nccl", "hccl"]: + engine_kwargs["is_master"] = True + self.checkpoint_engine = CheckpointEngineRegistry.new(backend, bucket_size=bucket_size, **engine_kwargs) @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) async def update_weights(self): @@ -43,61 +48,88 @@ def execute_checkpoint_engine(self, method: str, *args, **kwargs): return getattr(self.checkpoint_engine, method)(*args, **kwargs) -class RolloutWorkerTest: - def __init__( - self, - model_path, - checkpoint_backend: str, - checkpoint_kwargs: dict, - device: str = "cuda", - check_allclose: bool = True, - ) -> None: - self.checkpoint_engine = CheckpointEngineRegistry.new(checkpoint_backend, **checkpoint_kwargs) - if check_allclose: - local_path = copy_to_local(model_path) - self.model = AutoModelForCausalLM.from_pretrained(local_path, torch_dtype=torch.bfloat16) - self.model.to(device) +class MockServerAdapter(BaseRollout): + def __init__(self, config: RolloutConfig, model_config: HFModelConfig, check_allclose: bool = True): + super().__init__(config, model_config, device_mesh=None) self.check_allclose = check_allclose + self.model = None self.received_weights: dict[str, torch.Tensor] = {} - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - async def update_weights(self): - async for name, weight in self.checkpoint_engine.receive_weights(): + async def resume(self, tags: list[str]): + raise NotImplementedError() + + async def release(self): + raise NotImplementedError() + + async def update_weights( + self, + weights: Generator[tuple[str, torch.Tensor], None, None], + **kwargs, + ): + async for name, weight in weights: weight = weight.clone() if self.check_allclose: - self.received_weights[name] = weight.clone().to(torch.bfloat16) - - @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False) - def execute_checkpoint_engine(self, method: str, *args, **kwargs): - return getattr(self.checkpoint_engine, method)(*args, 
**kwargs) + self.received_weights[name] = weight.clone() - @register(dispatch_mode=Dispatch.ONE_TO_ALL) def check_weights(self): if not self.check_allclose: return + + if self.model is None: + local_path = copy_to_local(self.model_config.path) + self.model = AutoModelForCausalLM.from_pretrained(local_path, torch_dtype=torch.bfloat16, device_map="cpu") + for name, weight in self.model.state_dict().items(): assert name in self.received_weights, f"weight {name} not received" - assert torch.allclose(weight, self.received_weights[name]), f"weight {name} not equal" + received = self.received_weights[name] + assert torch.allclose(weight.to(received.device), received), f"weight {name} not equal" self.received_weights.clear() +class MockReplica(RolloutReplica): + async def init_hybrid(self, worker_group: RayWorkerGroup): + """Init hybrid rollout server, rollout engine and training engine(fsdp/megatron) fused in same process. + + Args: + worker_group: RayWorkerGroup, fused workers where training engine(fsdp/megatron) have been initialized. + """ + self.workers = worker_group.workers[ + self.world_size * self.replica_rank : self.world_size * (self.replica_rank + 1) + ] + + def get_ray_class_with_init_args(self) -> RayClassWithInitArgs: + """Get rollout worker actor class for colocated and standalone mode.""" + raise NotImplementedError + + async def launch_servers(self): + """Launch http server in each node.""" + raise NotImplementedError + + +class CheckpointEngineWorkerTest(CheckpointEngineWorker): + def __init__(self, rollout_config: RolloutConfig, model_config: HFModelConfig, check_allclose: bool = True) -> None: + server_adapter = MockServerAdapter(rollout_config, model_config, check_allclose) + super().__init__(rollout_config, model_config, server_adapter) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def check_weights(self): + self.server_adapter.check_weights() + + def create_trainer_worker_group( - model_path: str, resource_pool: RayResourcePool, checkpoint_backend: str, checkpoint_kwargs: dict + resource_pool: RayResourcePool, model_config: HFModelConfig, checkpoint_engine_config: CheckpointEngineConfig ) -> RayWorkerGroup: - local_path = copy_to_local(model_path) - model_config = HFModelConfig(path=local_path, use_remove_padding=True) engine_config = FSDPEngineConfig(forward_only=True, fsdp_size=resource_pool.world_size, strategy="fsdp") - trainer_config = TrainingWorkerConfig( model_type="language_model", model_config=model_config, engine_config=engine_config, ) + ray_cls_with_init = RayClassWithInitArgs( cls=ray.remote(TrainingWorkerTest), config=trainer_config, - checkpoint_backend=checkpoint_backend, - checkpoint_kwargs=checkpoint_kwargs, + checkpoint_engine_config=checkpoint_engine_config, ) ray_cls_with_init.update_options( { @@ -112,16 +144,36 @@ def create_trainer_worker_group( return wg -def create_rollout_worker_group( - model_path: str, resource_pool: RayResourcePool, checkpoint_backend: str, checkpoint_kwargs: dict, **kwargs -) -> RayWorkerGroup: +async def create_rollout_worker_group( + resource_pool: RayResourcePool, + model_config: HFModelConfig, + rollout_config: RolloutConfig, + check_allclose: bool = True, +) -> tuple[RayWorkerGroup, list[MockReplica]]: + # create rollout worker group ray_cls_with_init = RayClassWithInitArgs( - cls=ray.remote(RolloutWorkerTest), - model_path=model_path, - checkpoint_backend=checkpoint_backend, - checkpoint_kwargs=checkpoint_kwargs, - device=get_device_name(), - **kwargs, + cls=ray.remote(CheckpointEngineWorkerTest), + 
model_config=model_config, + rollout_config=rollout_config, + check_allclose=check_allclose, ) wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name=get_device_name()) - return wg + + # create rollout replicas + rollout_world_size = ( + rollout_config.tensor_model_parallel_size + * rollout_config.data_parallel_size + * rollout_config.pipeline_model_parallel_size + ) + num_replicas = wg.world_size // rollout_world_size + replicas = [] + for replica_rank in range(num_replicas): + replica = MockReplica( + replica_rank=replica_rank, + config=rollout_config, + model_config=model_config, + ) + replicas.append(replica) + await asyncio.gather(*[replica.init_hybrid(wg) for replica in replicas]) + + return wg, replicas diff --git a/tests/experimental/agent_loop/test_basic_agent_loop.py b/tests/experimental/agent_loop/test_basic_agent_loop.py index 6746db10137..7cb55bde48f 100644 --- a/tests/experimental/agent_loop/test_basic_agent_loop.py +++ b/tests/experimental/agent_loop/test_basic_agent_loop.py @@ -22,6 +22,7 @@ from transformers.utils import get_json_schema from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager +from verl.checkpoint_engine import CheckpointEngineManager from verl.experimental.agent_loop import AgentLoopManager from verl.experimental.agent_loop.agent_loop import get_trajectory_info from verl.protocol import DataProto @@ -348,6 +349,13 @@ def test_tool_agent_with_interaction(init_config): init_config.actor_rollout_ref.rollout.multi_turn.interaction_config_path = interaction_config_path init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2 agent_loop_manager = init_agent_loop_manager(init_config) + checkpoint_manager = CheckpointEngineManager( + backend=init_config.actor_rollout_ref.rollout.checkpoint_engine.backend, + trainer=agent_loop_manager.worker_group, + replicas=agent_loop_manager.rollout_replicas, + ) + checkpoint_manager.sleep_replicas() + checkpoint_manager.update_weights() # =========================== 2. Generate sequences =========================== raw_prompts = [ diff --git a/tests/experimental/agent_loop/test_standalone_rollout.py b/tests/experimental/agent_loop/test_standalone_rollout.py index c17e70bc19f..96b7912045b 100644 --- a/tests/experimental/agent_loop/test_standalone_rollout.py +++ b/tests/experimental/agent_loop/test_standalone_rollout.py @@ -20,6 +20,7 @@ from openai import AsyncOpenAI, OpenAI from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager +from verl.checkpoint_engine import CheckpointEngineManager from verl.workers.rollout.replica import get_rollout_replica_class @@ -122,12 +123,13 @@ def test_hybrid_rollout_with_ep(init_config): # - offload FSDP model and optimizer, build rollout # - sleep rollout and load FSDP model and optimizer agent_loop_manager = init_agent_loop_manager(init_config) - - # 2. wake up rollout - # - wake_up weights - # - load_weights from FSDP - # - wake_up kv_cache - agent_loop_manager.wake_up() + checkpoint_manager = CheckpointEngineManager( + backend=init_config.actor_rollout_ref.rollout.checkpoint_engine.backend, + trainer=agent_loop_manager.worker_group, + replicas=agent_loop_manager.rollout_replicas, + ) + checkpoint_manager.sleep_replicas() + checkpoint_manager.update_weights() # 3. 
test async openai call server_address = agent_loop_manager.server_addresses[0] diff --git a/verl/checkpoint_engine/__init__.py b/verl/checkpoint_engine/__init__.py index 3c20eb35401..4409369e8e8 100644 --- a/verl/checkpoint_engine/__init__.py +++ b/verl/checkpoint_engine/__init__.py @@ -12,9 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import CheckpointEngine, CheckpointEngineRegistry, ColocatedCheckpointEngine, TensorMeta - -__all__ = ["CheckpointEngine", "CheckpointEngineRegistry", "TensorMeta", "ColocatedCheckpointEngine"] +from .base import ( + CheckpointEngine, + CheckpointEngineManager, + CheckpointEngineRegistry, + CheckpointEngineWorker, + ColocatedCheckpointEngine, + TensorMeta, +) + +__all__ = [ + "CheckpointEngine", + "CheckpointEngineRegistry", + "TensorMeta", + "ColocatedCheckpointEngine", + "CheckpointEngineManager", + "CheckpointEngineWorker", +] try: from .nccl_checkpoint_engine import NCCLCheckpointEngine diff --git a/verl/checkpoint_engine/base.py b/verl/checkpoint_engine/base.py index c10284b1e42..86be0ceeb3f 100644 --- a/verl/checkpoint_engine/base.py +++ b/verl/checkpoint_engine/base.py @@ -11,12 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import asyncio from abc import ABC, abstractmethod -from typing import Generator, TypedDict +from typing import Any, Generator, TypedDict +import ray import torch +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, register +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.utils.distributed import initialize_global_process_group_ray +from verl.utils.ray_utils import auto_await +from verl.workers.config import HFModelConfig, RolloutConfig +from verl.workers.rollout import BaseRollout, RolloutReplica, get_rollout_class + class TensorMeta(TypedDict): name: str @@ -43,6 +52,18 @@ def wrapper(cls: type["CheckpointEngine"]): return wrapper + @classmethod + def get(cls, backend: str) -> type["CheckpointEngine"]: + """Get the checkpoint engine class. + + Args: + backend: The backend of the checkpoint engine. + + Returns: + The checkpoint engine class. + """ + return cls._registry[backend] + @classmethod def new(cls, backend: str, *args, **kwargs) -> "CheckpointEngine": """Create a new checkpoint engine instance. @@ -74,6 +95,69 @@ class CheckpointEngine(ABC): >>> await server_adapter.update_weights(engine.get_weights()) # update weights via cuda ipc """ + @abstractmethod + def prepare(self) -> dict[str, Any]: + """Prepare checkpoint engine before each step send_weights/receive_weights. + + 1. Allocate weight bucket. + 2. [Optional] Register weight bucket for RDMA. + 3. Return metadata to build communication topology: master ip:port, register RDMA description, etc. + + Args: + worker_group: The worker group that the checkpoint engine will be used. + + Returns: + A dictionary that contains the metadata of the worker group. + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def build_topology( + cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict] + ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]: + """Build communication topology between all workers. + + Args: + trainer_world_size: The world size of the trainer worker group. 
+ rollout_world_size: The world size of the rollout replica. + metadata: A list of metadata returned by `prepare` from all workers. + + Returns: + A tuple of two dictionaries containing the communication topology for the trainer and rollout worker groups. + Each dict value should be a list whose length equals the world size of the worker group, to dispatch to + `init_process_group`. + + ``` + world_size = rollout.world_size + trainer.world_size + kwargs = { + "rank": list(range(world_size)), + "world_size": [world_size] * world_size, + "master_metadata": [metadata[0]] * world_size, + } + ``` + """ + raise NotImplementedError + + @abstractmethod + def init_process_group(self, **kwargs): + """Init process group for checkpoint engine. + + Args: + **kwargs: Keyword arguments from `build_topology`. + """ + raise NotImplementedError + + @abstractmethod + def finalize(self): + """Finalize checkpoint engine after each step send_weights/receive_weights. + + 1. Free weight bucket. + 2. [Optional] Deregister weight bucket for RDMA. + 3. [Optional] Destroy process group. + """ + raise NotImplementedError + @abstractmethod async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): """Send the weights of the model. @@ -122,6 +206,22 @@ class ColocatedCheckpointEngine(CheckpointEngine): >>> server_adapter.update_weights(engine.receive_weights()) """ + def __init__(self, bucket_size: int) -> None: + self.bucket_size = bucket_size + + def prepare(self): + raise NotImplementedError + + def init_process_group(self, **kwargs): + raise NotImplementedError + + def finalize(self): + raise NotImplementedError + + @classmethod + def build_topology(cls, *args, **kwargs): + raise NotImplementedError + def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): """Send the weights of the model. @@ -138,3 +238,172 @@ def receive_weights(self) -> Generator[tuple[str, torch.Tensor], None, None]: """ yield from self.weights self.weights = None + + +class CheckpointEngineWorker(Worker): + """CheckpointEngineWorker colocated with the inference engine's WorkerProc on the same GPU. + + Args: + rollout_config: The rollout configuration. + model_config: The model configuration. + server_adapter: The server adapter to update weights. 
+ """ + + def __init__( + self, + rollout_config: RolloutConfig, + model_config: HFModelConfig, + server_adapter: BaseRollout = None, + ) -> None: + self.rollout_config = rollout_config + self.model_config = model_config + + # sglang and trt-llm need device_mesh for internal communication + initialize_global_process_group_ray(timeout_second=None, backend="cpu:gloo") + self.server_adapter: BaseRollout = server_adapter or get_rollout_class( + rollout_config.name, rollout_config.mode + )(config=rollout_config, model_config=model_config, device_mesh=None) + + backend = rollout_config.checkpoint_engine.backend + bucket_size = rollout_config.checkpoint_engine.update_weights_bucket_megabytes << 20 + engine_kwargs = rollout_config.checkpoint_engine.engine_kwargs.get(backend, {}) + self.checkpoint_engine = CheckpointEngineRegistry.new(backend, bucket_size=bucket_size, **engine_kwargs) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + async def update_weights(self): + weights = self.checkpoint_engine.receive_weights() + await self.server_adapter.update_weights(weights) + + @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False) + def execute_checkpoint_engine(self, method: str, *args, **kwargs): + return getattr(self.checkpoint_engine, method)(*args, **kwargs) + + +_worker_cls = ray.remote(CheckpointEngineWorker) + + +class CheckpointEngineManager: + """Checkpoint engine manager to coordinate weight synchronization between trainer and rollout replicas. + + - ME: model engine, FSDP, MCore, VeOmni, export full tensor generator `get_per_tensor_param` + - CE: checkpoint engine, NCCL, NIXL, etc + + In trainer, model engine and checkpoint engine are in same process. + In rollout, checkpoint engine and rollout worker are in separate process, update weights via cuda ipc. + + ``` + ┌────────┬────────┬─────┬────────┐ ┌───────────────────┬───────────────────┐ + │ ┌────┐ │ ┌────┐ │ │ ┌────┐ │ │ Replica 0 │ Replica 1 │ + │ │ ME0│ │ │ ME1│ │ │ │ MEn│ │ ├────┬────┬────┬────┼────┬────┬────┬────┤ + │ └──┬─┘ │ └────┘ │ ... │ └────┘ │ │ 0 │ 1 │ 2 │ 3 │ 0 │ 1 │ 2 │ 3 │ + │ v | | | | └──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┘ + | ┌──┴─┐ │ ┌────┐ │ │ ┌────┐ │ ^ ^ ^ cuda ipc ^ ^ ^ + │ │ CE │ │ │ CE │ │ │ │ CE │ │ ┌──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┐ + │ └──┬─┘ │ └────┘ │ │ └────┘ │ │ CE │ CE │ CE │ CE │ CE │ CE │ CE │ CE | + └────┼───┴────────┴─────┴────────┘ └──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┘ + v | | | | | | | | + └─────────────(nccl/nixl/..)─────────────┴────┴────┴────┴────┴────┴────┴────┘ + ``` + + Args: + backend: The checkpoint engine backend. + trainer: The trainer worker group. + replicas: The list of rollout replicas. + """ + + def __init__( + self, + backend: str, + trainer: RayWorkerGroup, + replicas: list[RolloutReplica], + ) -> None: + self.backend = backend + self.backend_cls = CheckpointEngineRegistry.get(backend) + self.trainer = trainer + self.replicas = replicas + + def build_process_group(self, rollout: RayWorkerGroup): + """Build process group for trainer and rollout replicas.""" + trainer = self.trainer + + # 1. prepare all workers + metadata = ray.get( + trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size) + + rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size) + ) + + # 2. 
build communication topology between all workers + trainer_kwargs, rollout_kwargs = self.backend_cls.build_topology( + trainer.world_size, rollout.world_size, metadata + ) + for k, v in trainer_kwargs.items(): + assert len(v) == trainer.world_size, f"trainer_kwargs[{k}] must have length of {trainer.world_size}" + for k, v in rollout_kwargs.items(): + assert len(v) == rollout.world_size, f"rollout_kwargs[{k}] must have length of {rollout.world_size}" + + trainer_kwargs["method"] = ["init_process_group"] * trainer.world_size + rollout_kwargs["method"] = ["init_process_group"] * rollout.world_size + + # 3. init process group between all workers + ray.get( + trainer.execute_checkpoint_engine(**trainer_kwargs) + rollout.execute_checkpoint_engine(**rollout_kwargs) + ) + + def add_replicas(self, replicas: list[RolloutReplica]): + """Add rollout replicas to the manager for elastic scale-up; the process group will be rebuilt. + + Args: + replicas: The list of rollout replicas to add. + """ + self.replicas.extend(replicas) + + def remove_replicas(self, replicas: list[RolloutReplica]): + """Remove rollout replicas from the manager for elastic scale-down; the process group will be rebuilt. + + Args: + replicas: The list of rollout replicas to remove. + """ + replicas_set = set(replicas) + self.replicas = [r for r in self.replicas if r not in replicas_set] + + @auto_await + async def sleep_replicas(self): + """Sleep all rollout replicas: free weight and kv_cache device memory.""" + # skip sleeping replicas for disaggregated rollout + if self.backend != "naive": + return + await asyncio.gather(*[r.sleep() for r in self.replicas]) + + @auto_await + async def update_weights(self): + """Update weights from trainer to rollout replicas.""" + + # 0. update weights for sync training with colocated trainer and rollout + if self.backend == "naive": + ray.get(self.trainer.update_weights()) + return + + # 1. abort and save all unfinished requests for partial rollout + await asyncio.gather(*[r.abort_all_requests() for r in self.replicas]) + + # 2. create a temporary worker group for all replicas + workers = [] + for replica in self.replicas: + workers.extend(replica.workers) + rollout = RayWorkerGroup(worker_handles=workers, ray_cls_with_init=RayClassWithInitArgs(cls=_worker_cls)) + trainer = self.trainer + + # 3. build process group + self.build_process_group(rollout) + + # 4. update weights of all workers + ray.get(trainer.update_weights() + rollout.update_weights()) + + # 5. finalize all workers + ray.get( + trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size) + + rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size) + ) + + # 6. 
resume all unfinished requests for partial rollout + await asyncio.gather(*[r.resume_all_requests() for r in self.replicas]) diff --git a/verl/checkpoint_engine/hccl_checkpoint_engine.py b/verl/checkpoint_engine/hccl_checkpoint_engine.py index 448c2f55aee..eb4c0df0bc3 100644 --- a/verl/checkpoint_engine/hccl_checkpoint_engine.py +++ b/verl/checkpoint_engine/hccl_checkpoint_engine.py @@ -137,7 +137,7 @@ def prepare(self) -> MasterMetadata: else None ) - def finish(self): + def finalize(self): """Destroy the HCCL process group if rebuild_group is True.""" if self.rebuild_group: if self.rank >= 0: @@ -149,6 +149,20 @@ def finish(self): self.send_buf = None self.recv_buf = None + @classmethod + def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): + trainer_kwargs = { + "rank": [0] + [-1] * (trainer_world_size - 1), + "world_size": [rollout_world_size + 1] * trainer_world_size, + "master_metadata": [metadata[0]] * trainer_world_size, + } + rollout_kwargs = { + "rank": list(range(1, rollout_world_size + 1)), + "world_size": [rollout_world_size + 1] * rollout_world_size, + "master_metadata": [metadata[0]] * rollout_world_size, + } + return trainer_kwargs, rollout_kwargs + def _start_zmq_server(self): self.ip = ray.util.get_node_ip_address().strip("[]") self.zmq_port, self.listen_sock = get_free_port(self.ip) diff --git a/verl/checkpoint_engine/nccl_checkpoint_engine.py b/verl/checkpoint_engine/nccl_checkpoint_engine.py index 9b62682eb15..526bf97347e 100644 --- a/verl/checkpoint_engine/nccl_checkpoint_engine.py +++ b/verl/checkpoint_engine/nccl_checkpoint_engine.py @@ -17,8 +17,11 @@ import time from dataclasses import dataclass from typing import AsyncGenerator, Generator +from unittest.mock import patch + +with patch("importlib.metadata.distributions", return_value=[]): + import cupy as cp -import cupy as cp import ray import ray.util.collective as collective import torch @@ -134,7 +137,7 @@ def prepare(self) -> MasterMetadata: return MasterMetadata(zmq_ip=self.ip, zmq_port=self.listen_port) if self.is_master else None - def finish(self): + def finalize(self): """Destroy the NCCL process group if rebuild_group is True.""" if self.rebuild_group: if self.rank >= 0: @@ -145,6 +148,20 @@ def finish(self): self.send_buf = None self.recv_buf = None + @classmethod + def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): + trainer_kwargs = { + "rank": [0] + [-1] * (trainer_world_size - 1), + "world_size": [rollout_world_size + 1] * trainer_world_size, + "master_metadata": [metadata[0]] * trainer_world_size, + } + rollout_kwargs = { + "rank": list(range(1, rollout_world_size + 1)), + "world_size": [rollout_world_size + 1] * rollout_world_size, + "master_metadata": [metadata[0]] * rollout_world_size, + } + return trainer_kwargs, rollout_kwargs + def _start_zmq_server(self): self.ip = ray.util.get_node_ip_address().strip("[]") self.listen_port, self.listen_sock = get_free_port(self.ip) diff --git a/verl/checkpoint_engine/nixl_checkpoint_engine.py b/verl/checkpoint_engine/nixl_checkpoint_engine.py index d102409bbcb..01d7c50da6d 100644 --- a/verl/checkpoint_engine/nixl_checkpoint_engine.py +++ b/verl/checkpoint_engine/nixl_checkpoint_engine.py @@ -19,8 +19,11 @@ from collections import defaultdict, deque from dataclasses import dataclass from typing import AsyncGenerator, Generator +from unittest.mock import patch + +with patch("importlib.metadata.distributions", return_value=[]): + import cupy as cp -import cupy as cp 
import nixl._api as nixl_api import nixl._bindings as nixl_bindings import ray @@ -274,6 +277,25 @@ def prepare(self) -> NixlAgentMetadata: return self.agent.get_agent_metadata() + @classmethod + def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): + trainer_kwargs = { + "method": ["init_process_group"] * trainer_world_size, + "rank": [0] + [-1] * (trainer_world_size - 1), + "world_size": [rollout_world_size + 1] * trainer_world_size, + "prev_agent_metadata": [None] * trainer_world_size, + "next_agent_metadata": [metadata[-rollout_world_size]] + [None] * (trainer_world_size - 1), + } + + rollout_kwargs = { + "method": ["init_process_group"] * rollout_world_size, + "rank": list(range(1, rollout_world_size + 1)), + "world_size": [rollout_world_size + 1] * rollout_world_size, + "prev_agent_metadata": [metadata[0]] + metadata[-rollout_world_size:-1], + "next_agent_metadata": metadata[-rollout_world_size + 1 :] + [None], + } + return trainer_kwargs, rollout_kwargs + def init_process_group( self, rank: int, world_size: int, prev_agent_metadata: NixlAgentMetadata, next_agent_metadata: NixlAgentMetadata ): @@ -316,7 +338,7 @@ def init_process_group( f"prev_agent: {self.prev_agent}, next_agent: {self.next_agent}" ) - def finish(self): + def finalize(self): """Cleanup communication with the previous and next agent, and deregister the memory.""" if self.prev_agent: self.agent.remove_remote_agent(self.prev_agent) diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index db4332fd565..f2be85d5b4e 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -879,10 +879,6 @@ def __init__( self._initialize_llm_servers(rollout_resource_pool) self._init_agent_loop_workers() - # Initially we're in sleep mode. - if self.config.actor_rollout_ref.rollout.free_cache_engine: - self.sleep() - def _initialize_llm_servers(self, rollout_resource_pool: RayResourcePool): rollout_world_size = ( self.config.actor_rollout_ref.rollout.tensor_model_parallel_size @@ -958,9 +954,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: DataProto: Output batch. 
""" - # Fix for Issue #4147: Always call wake_up() to ensure weight sync - # The wake_up()/sleep() methods internally check free_cache_engine - self.wake_up() + # TODO: move reward_model_manager out of agent_loop manager if self.reward_model_manager: self.reward_model_manager.wake_up() @@ -972,8 +966,6 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: ] ) output = DataProto.concat(outputs) - # Fix for Issue #4147: Always call sleep() to ensure proper cleanup - self.sleep() if self.reward_model_manager: self.reward_model_manager.sleep() @@ -1011,14 +1003,6 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data return timing - def wake_up(self): - """Wake up all rollout replica instances.""" - self._run_all([replica.wake_up() for replica in self.rollout_replicas]) - - def sleep(self): - """Sleep all rollout replica instances.""" - self._run_all([replica.sleep() for replica in self.rollout_replicas]) - def clear_kv_cache(self): """Clear all rollout kv cache, but don`t sleep.""" self._run_all([replica.clear_kv_cache() for replica in self.rollout_replicas]) diff --git a/verl/experimental/transfer_queue/agent_loop.py b/verl/experimental/transfer_queue/agent_loop.py index 4e44c2ff9c0..0887b4600e9 100644 --- a/verl/experimental/transfer_queue/agent_loop.py +++ b/verl/experimental/transfer_queue/agent_loop.py @@ -31,8 +31,6 @@ def generate_sequences(self, prompts: BatchMeta) -> BatchMeta: BatchMeta: Output batch metadata. """ - if self.config.actor_rollout_ref.rollout.free_cache_engine: - self.wake_up() if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine: self.reward_model_manager.wake_up() @@ -44,8 +42,6 @@ def generate_sequences(self, prompts: BatchMeta) -> BatchMeta: ] ) output = BatchMeta.concat(outputs) - if self.config.actor_rollout_ref.rollout.free_cache_engine: - self.sleep() if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine: self.reward_model_manager.sleep() diff --git a/verl/experimental/transfer_queue/ray_trainer.py b/verl/experimental/transfer_queue/ray_trainer.py index b29682dba31..1f2be802b0f 100644 --- a/verl/experimental/transfer_queue/ray_trainer.py +++ b/verl/experimental/transfer_queue/ray_trainer.py @@ -24,7 +24,6 @@ import os import uuid from collections import defaultdict -from dataclasses import dataclass, field from pprint import pprint from typing import Any, Optional @@ -47,8 +46,9 @@ ) from verl import DataProto +from verl.checkpoint_engine import CheckpointEngineManager from verl.experimental.dataset.sampler import AbstractCurriculumSampler -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.config import AlgoConfig from verl.trainer.ppo import core_algos @@ -72,63 +72,6 @@ from verl.utils.transferqueue_utils import create_transferqueue_client, get_transferqueue_client, tqbridge -@dataclass -class ResourcePoolManager: - """ - Define a resource pool specification. Resource pool will be initialized first. - """ - - resource_pool_spec: dict[str, list[int]] - mapping: dict[Role, str] - resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) - - def create_resource_pool(self): - """Create Ray resource pools for distributed training. 
- - Initializes resource pools based on the resource pool specification, - with each pool managing GPU resources across multiple nodes. - For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups. - For Megatron backend, uses max_colocate_count>1 for different models. - """ - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool - # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. - # For Megatron backend, we recommend using max_colocate_count>1 - # that can utilize different WorkerGroup for differnt models - resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name - ) - self.resource_pool_dict[resource_pool_name] = resource_pool - - self._check_resource_available() - - def get_resource_pool(self, role: Role) -> RayResourcePool: - """Get the resource pool of the worker_cls""" - return self.resource_pool_dict[self.mapping[role]] - - def get_n_gpus(self) -> int: - """Get the number of gpus in this cluster.""" - return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) - - def _check_resource_available(self): - """Check if the resource pool can be satisfied in this ray cluster.""" - node_available_resources = ray._private.state.available_resources_per_node() - node_available_gpus = { - node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) - for node, node_info in node_available_resources.items() - } - - # check total required gpus can be satisfied - total_available_gpus = sum(node_available_gpus.values()) - total_required_gpus = sum( - [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] - ) - if total_available_gpus < total_required_gpus: - raise ValueError( - f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" - ) - - @tqbridge(put_data=False) def compute_reward_decorated(data, reward_fn): return compute_reward(data, reward_fn) @@ -912,6 +855,15 @@ def init_workers(self): rm_resource_pool=rm_resource_pool, ) + self.checkpoint_manager = CheckpointEngineManager( + backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend, + trainer=self.actor_rollout_wg, + replicas=self.async_rollout_manager.rollout_replicas, + ) + + # sleep all replicas to load checkpoint + self.checkpoint_manager.sleep_replicas() + # TODO (TQ): initialize tq during worker init when enable TQ switch is stable self.async_rollout_manager.create_transferqueue_client_for_workers() @@ -1186,8 +1138,9 @@ def fit(self): self.global_steps = 0 - # load checkpoint before doing anything + # load checkpoint and update weights before doing anything self._load_checkpoint() + self.checkpoint_manager.update_weights() # perform validation before training # currently, we only support validation using the reward_function. 
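The trainer-side contract of the new CheckpointEngineManager is small: sleep the replicas before loading a checkpoint, push weights before the first rollout, sleep again after each generation, and push again after each actor update (see the fit() hunks here and below). The sketch uses a stub with the same two method names purely to make the ordering concrete; it is not the real manager.

    # Stub that only mimics the two CheckpointEngineManager calls used by the trainer.
    class StubCheckpointEngineManager:
        def sleep_replicas(self):
            print("rollout replicas released (weights + kv_cache)")

        def update_weights(self):
            print("trainer weights pushed to rollout; replicas resumed")

    def training_loop_sketch(num_steps: int = 2):
        manager = StubCheckpointEngineManager()
        manager.sleep_replicas()      # init_workers: free rollout memory before loading the checkpoint
        print("load checkpoint")      # self._load_checkpoint()
        manager.update_weights()      # sync weights before the first rollout
        for step in range(num_steps):
            print(f"step {step}: generate_sequences")
            manager.sleep_replicas()  # rollout done, hand the memory back to the trainer
            print(f"step {step}: update_actor")
            manager.update_weights()  # sync the new actor weights for the next rollout

    training_loop_sketch()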
@@ -1258,6 +1211,7 @@ def fit(self): gen_output_meta = self.actor_rollout_wg.generate_sequences(gen_meta) else: gen_output_meta = self.async_rollout_manager.generate_sequences(gen_meta) + self.checkpoint_manager.sleep_replicas() timing_raw.update(gen_output_meta.extra_info["timing"]) gen_output_meta.extra_info.pop("timing", None) @@ -1561,6 +1515,11 @@ def fit(self): actor_output_meta = self.actor_rollout_wg.update_actor(update_actor_meta) batch_meta = batch_meta.union(actor_output_meta) + + # update weights from trainer to rollout + with marked_timer("update_weights", timing_raw, color="red"): + self.checkpoint_manager.update_weights() + actor_output_metrics = reduce_metrics(actor_output_meta.extra_info["metrics"]) metrics.update(actor_output_metrics) diff --git a/verl/single_controller/ray/__init__.py b/verl/single_controller/ray/__init__.py index 6679166e370..b60291d23ac 100644 --- a/verl/single_controller/ray/__init__.py +++ b/verl/single_controller/ray/__init__.py @@ -16,6 +16,7 @@ RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, + ResourcePoolManager, SubRayResourcePool, create_colocated_worker_cls, create_colocated_worker_cls_fused, @@ -26,6 +27,7 @@ "RayResourcePool", "SubRayResourcePool", "RayWorkerGroup", + "ResourcePoolManager", "create_colocated_worker_cls", "create_colocated_worker_cls_fused", ] diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index 7bc069239d2..d632be4f6fb 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -16,6 +16,7 @@ import os import socket from copy import deepcopy +from dataclasses import dataclass, field from typing import Any, Optional import numpy as np @@ -163,6 +164,63 @@ def world_size(self): return self.subgroup_world_size +@dataclass +class ResourcePoolManager: + """ + Define a resource pool specification. Resource pool will be initialized first. + """ + + resource_pool_spec: dict[str, list[int]] + mapping: dict[int, str] + resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) + + def create_resource_pool(self): + """Create Ray resource pools for distributed training. + + Initializes resource pools based on the resource pool specification, + with each pool managing GPU resources across multiple nodes. + For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups. + For Megatron backend, uses max_colocate_count>1 for different models. + """ + for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): + # max_colocate_count means the number of WorkerGroups (i.e. 
processes) in each RayResourcePool + # For FSDP backend, using max_colocate_count=3: actor_critic_ref, rollout, reward model (optional) + # For Megatron backend, we recommend using max_colocate_count>1 + # that can utilize different WorkerGroups for different models + resource_pool = RayResourcePool( + process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=3, name_prefix=resource_pool_name + ) + self.resource_pool_dict[resource_pool_name] = resource_pool + + self._check_resource_available() + + def get_resource_pool(self, role) -> RayResourcePool: + """Get the resource pool of the worker_cls""" + return self.resource_pool_dict[self.mapping[role]] + + def get_n_gpus(self) -> int: + """Get the number of gpus in this cluster.""" + return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + + def _check_resource_available(self): + """Check if the resource pool can be satisfied in this ray cluster.""" + node_available_resources = ray._private.state.available_resources_per_node() + node_available_gpus = { + node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) + for node, node_info in node_available_resources.items() + } + + # check total required gpus can be satisfied + total_available_gpus = sum(node_available_gpus.values()) + total_required_gpus = sum( + [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] + ) + if total_available_gpus < total_required_gpus: + raise ValueError( + f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" + ) + + def extract_pg_from_exist( resource_pools: dict[str, RayResourcePool], src_role_names: list[str], resource_pool: RayResourcePool ) -> list: @@ -376,7 +434,7 @@ def __init__( self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix self._ray_wait_register_center_timeout = ray_wait_register_center_timeout # Whether the WorkerGroup is a Colocate WorkerGroup created by FusedWorker. - self.fused_worker_used = ray_cls_with_init.fused_worker_used + self.fused_worker_used = False if ray_cls_with_init is None else ray_cls_with_init.fused_worker_used # if a WorkerGroup is spawned from Colocate WorkerGroup, this indicates which sub-class is binded to # this WorkerGroup.
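ResourcePoolManager is now defined once in verl/single_controller/ray/base.py and re-exported from verl.single_controller.ray, so both trainers import the same implementation. A small usage sketch, assuming the verl package from this PR is installed; role keys are shown as plain strings for brevity (the trainers pass Role enum members), and create_resource_pool() needs a running Ray cluster with enough GPUs, so it is left commented out.

    from verl.single_controller.ray import ResourcePoolManager

    spec = {"global_pool": [8, 8]}               # two nodes with 8 GPUs each
    mapping = {"actor_rollout": "global_pool"}   # role -> pool name

    manager = ResourcePoolManager(resource_pool_spec=spec, mapping=mapping)
    print(manager.get_n_gpus())  # 16, computed from the spec without touching Ray

    # On a live Ray cluster with >= 16 GPUs this would build the pools
    # (max_colocate_count=3: actor_critic_ref, rollout, optional reward model):
    # manager.create_resource_pool()
    # pool = manager.get_resource_pool("actor_rollout")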
self.sub_cls_name = "" @@ -435,7 +493,7 @@ def _init_with_detached_workers(self, worker_names, worker_handles): # https://github.com/ray-project/ray/pull/45699 workers = worker_handles if worker_handles else [ray.get_actor(name=name) for name in worker_names] self._workers = workers - self._world_size = len(worker_names) + self._world_size = len(workers) def _get_master_addr_port(self, pg, bundle_index=0): """Get master addr and port for this worker group""" diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 57de81c74d9..50188e87a0a 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -268,7 +268,11 @@ actor_rollout_ref: _target_: verl.workers.config.CustomAsyncServerConfig path: null name: null - update_weights_bucket_megabytes: 2048 + checkpoint_engine: + _target_: verl.workers.config.CheckpointEngineConfig + backend: naive + update_weights_bucket_megabytes: 2048 + engine_kwargs: {} trace: _target_: verl.workers.config.TraceConfig backend: null diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 1114171680e..945351618d1 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -259,7 +259,11 @@ actor_rollout_ref: _target_: verl.workers.config.CustomAsyncServerConfig path: null name: null - update_weights_bucket_megabytes: 2048 + checkpoint_engine: + _target_: verl.workers.config.CheckpointEngineConfig + backend: naive + update_weights_bucket_megabytes: 2048 + engine_kwargs: {} trace: _target_: verl.workers.config.TraceConfig backend: null diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index c4bbf8c52a7..2a520cf1186 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -241,19 +241,31 @@ agent: # Class name of the custom async server class (e.g. AsyncvLLMServer) name: null -# Specifies the tensor bucket size (in megabytes) for batch weight updates during rollout operations. -# This parameter controls the maximum payload size for a single weight update request. -# Reference: https://github.com/volcengine/verl/pull/2418 -# Currently only supported in SGLang rollout implementations -# Larger values may improve throughput but increase memory overhead -# Detailed performance comparison: -# https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/169#issuecomment-3070686720 -# Default value (512MB) is optimized for typical GPU memory configurations -# For the best performance of `rebuild_cuda_tensor`, it is recommended to: -# 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES` -# 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` -# when using Tensor Parallelism (TP) >= 8. -update_weights_bucket_megabytes: 2048 +# Checkpoint Engine config for update weights from trainer to rollout +checkpoint_engine: + + # Target class for checkpoint engine config + _target_: verl.workers.config.CheckpointEngineConfig + + # Backend for checkpoint engine: naive, nccl, nixl, hccl + backend: naive + + # Specifies the tensor bucket size (in megabytes) for batch weight updates during rollout operations. + # This parameter controls the maximum payload size for a single weight update request. 
+ # Reference: https://github.com/volcengine/verl/pull/2418 + # Supported in the SGLang, vLLM, and TensorRT-LLM rollout implementations + # Larger values may improve throughput but increase memory overhead + # Detailed performance comparison: + # https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/169#issuecomment-3070686720 + # Default value (2048MB) is optimized for typical GPU memory configurations + # For the best performance of `rebuild_cuda_tensor`, it is recommended to: + # 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES` + # 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` + # when using Tensor Parallelism (TP) >= 8. + update_weights_bucket_megabytes: 2048 + + # Additional keyword arguments to pass to the checkpoint engine constructor + engine_kwargs: {} # trace rollout data trace: diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 0fb205e98e3..faa6438c203 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -23,7 +23,6 @@ import uuid from collections import defaultdict from copy import deepcopy -from dataclasses import dataclass, field from pprint import pprint from typing import Any, Optional @@ -36,9 +35,10 @@ from tqdm import tqdm from verl import DataProto +from verl.checkpoint_engine import CheckpointEngineManager from verl.experimental.dataset.sampler import AbstractCurriculumSampler from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.config import AlgoConfig from verl.trainer.ppo import core_algos @@ -67,63 +67,6 @@ from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding -@dataclass -class ResourcePoolManager: - """ - Define a resource pool specification. Resource pool will be initialized first. - """ - - resource_pool_spec: dict[str, list[int]] - mapping: dict[Role, str] - resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) - - def create_resource_pool(self): - """Create Ray resource pools for distributed training. - - Initializes resource pools based on the resource pool specification, - with each pool managing GPU resources across multiple nodes. - For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups. - For Megatron backend, uses max_colocate_count>1 for different models. - """ - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - # max_colocate_count means the number of WorkerGroups (i.e.
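Because the bucket-size knob moved under rollout.checkpoint_engine, existing overrides of update_weights_bucket_megabytes need the new dotted path. A quick OmegaConf sketch of the new keys; the values are illustrative, and the same dotted paths can be passed as Hydra-style command-line overrides to the trainer entry point.

    from omegaconf import OmegaConf

    overrides = OmegaConf.from_dotlist([
        "actor_rollout_ref.rollout.checkpoint_engine.backend=nccl",
        "actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=4096",
    ])
    node = overrides.actor_rollout_ref.rollout.checkpoint_engine
    print(node.backend, node.update_weights_bucket_megabytes)  # nccl 4096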
processes) in each RayResourcePool - # For FSDP backend, using max_colocate_count=3: actor_critic_ref, rollout, reward model (optional) - # For Megatron backend, we recommend using max_colocate_count>1 - # that can utilize different WorkerGroup for differnt models - resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=3, name_prefix=resource_pool_name - ) - self.resource_pool_dict[resource_pool_name] = resource_pool - - self._check_resource_available() - - def get_resource_pool(self, role: Role) -> RayResourcePool: - """Get the resource pool of the worker_cls""" - return self.resource_pool_dict[self.mapping[role]] - - def get_n_gpus(self) -> int: - """Get the number of gpus in this cluster.""" - return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) - - def _check_resource_available(self): - """Check if the resource pool can be satisfied in this ray cluster.""" - node_available_resources = ray._private.state.available_resources_per_node() - node_available_gpus = { - node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) - for node, node_info in node_available_resources.items() - } - - # check total required gpus can be satisfied - total_available_gpus = sum(node_available_gpus.values()) - total_required_gpus = sum( - [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] - ) - if total_available_gpus < total_required_gpus: - raise ValueError( - f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" - ) - - def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): """Apply KL penalty to the token-level rewards. 
@@ -973,6 +916,15 @@ def init_workers(self): rm_resource_pool=rm_resource_pool, ) + self.checkpoint_manager = CheckpointEngineManager( + backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend, + trainer=self.actor_rollout_wg, + replicas=self.async_rollout_manager.rollout_replicas, + ) + + # sleep all replicas to load checkpoint + self.checkpoint_manager.sleep_replicas() + def _save_checkpoint(self): from verl.utils.fs import local_mkdir_safe @@ -1314,6 +1266,7 @@ def _update_actor(self, batch: DataProto) -> DataProto: actor_output = DataProto.from_single_dict(data={}, meta_info={"metrics": actor_output}) else: actor_output = self.actor_rollout_wg.update_actor(batch) + return actor_output def _update_critic(self, batch: DataProto) -> DataProto: @@ -1366,8 +1319,9 @@ def fit(self): self.global_steps = 0 - # load checkpoint before doing anything + # load checkpoint and update weights before doing anything self._load_checkpoint() + self.checkpoint_manager.update_weights() current_epoch = self.global_steps // len(self.train_dataloader) @@ -1440,6 +1394,7 @@ def fit(self): if curr_step_profile: self.async_rollout_manager.start_profile() gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + self.checkpoint_manager.sleep_replicas() if curr_step_profile: self.async_rollout_manager.stop_profile() @@ -1459,6 +1414,7 @@ def fit(self): if curr_step_profile: self.async_rollout_manager.start_profile() gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + self.checkpoint_manager.sleep_replicas() if curr_step_profile: self.async_rollout_manager.stop_profile() batch = batch.union(gen_baseline_output) @@ -1641,6 +1597,11 @@ def fit(self): # update actor with marked_timer("update_actor", timing_raw, color="red"): actor_output = self._update_actor(batch) + + # update weights from trainer to rollout + with marked_timer("update_weights", timing_raw, color="red"): + self.checkpoint_manager.update_weights() + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) diff --git a/verl/utils/distributed.py b/verl/utils/distributed.py index 7c0a3485ef1..17135584edf 100644 --- a/verl/utils/distributed.py +++ b/verl/utils/distributed.py @@ -73,18 +73,18 @@ def destroy_global_process_group(): torch.distributed.destroy_process_group() -def initialize_global_process_group_ray(timeout_second=None): +def initialize_global_process_group_ray(timeout_second=None, backend=None): # in current ray environment, LOCAL_RANK is always zero. import torch.distributed timeout = timedelta(seconds=timeout_second) if timeout_second is not None else None - + backend = backend or f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}" if not torch.distributed.is_initialized(): rank = int(os.environ.get("RANK", 0)) world_size = int(os.environ.get("WORLD_SIZE", 1)) torch.distributed.init_process_group( - backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}", + backend=backend, rank=rank, world_size=world_size, timeout=timeout, diff --git a/verl/utils/ray_utils.py b/verl/utils/ray_utils.py index b7ce7becd62..5ba20649365 100644 --- a/verl/utils/ray_utils.py +++ b/verl/utils/ray_utils.py @@ -17,6 +17,8 @@ import asyncio import concurrent.futures +import functools +import inspect import os from typing import Any, Optional @@ -90,3 +92,31 @@ def get_event_loop(): asyncio.set_event_loop(loop) return loop + + +def auto_await(func): + """Auto await a coroutine function. 
+ + If the function is called in an async context (with a running event loop), + it will return the coroutine object. Otherwise, it will block the current thread + and run the coroutine until completion. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + coro = func(*args, **kwargs) + + if not inspect.iscoroutine(coro): + return coro + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + return coro + else: + return asyncio.run(coro) + + return wrapper diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index d58da777b7e..3b4e7c121a3 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -30,6 +30,7 @@ "ServerConfig", "PrometheusConfig", "RolloutConfig", + "CheckpointEngineConfig", ] @@ -117,6 +118,20 @@ class PrometheusConfig(BaseConfig): served_model_name: Optional[str] = None +@dataclass +class CheckpointEngineConfig(BaseConfig): + """ + Configuration for checkpoint engine to update weights from trainer to rollout + """ + + # Backend for checkpoint engine: naive, nccl, nixl, hccl + backend: Optional[str] = MISSING + # Bucket size in MB to transfer multiple weights at one time + update_weights_bucket_megabytes: int = 2048 + # Additional keyword arguments for checkpoint engine + engine_kwargs: dict = field(default_factory=dict) + + @dataclass class RolloutConfig(BaseConfig): _mutable_fields = {"max_model_len", "load_format"} @@ -188,7 +203,8 @@ class RolloutConfig(BaseConfig): # Extension point for custom configurations custom: Optional[dict] = None - update_weights_bucket_megabytes: int = 512 + # Checkpoint Engine config for update weights from trainer to rollout + checkpoint_engine: CheckpointEngineConfig = field(default_factory=CheckpointEngineConfig) skip_rollout: bool = False diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index 966a8c30e6d..235b7b42009 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -23,11 +23,12 @@ from tensordict import NonTensorData, TensorDict from torch.distributed.device_mesh import init_device_mesh +from verl.checkpoint_engine import CheckpointEngineRegistry from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register from verl.utils import tensordict_utils as tu from verl.utils.config import omega_conf_to_dataclass -from verl.utils.device import get_device_name, get_torch_device, set_expandable_segments +from verl.utils.device import get_device_name, set_expandable_segments from verl.utils.distributed import initialize_global_process_group_ray from verl.utils.flops_counter import FlopsCounter from verl.utils.memory_utils import aggressive_empty_cache @@ -484,6 +485,7 @@ def init_model(self): if "rollout" in self.role: rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout) + # TODO: move rollout_device_mesh into ServerAdapter # 3.1 build rollout device mesh (sglang need only) infer_tp = rollout_config.tensor_model_parallel_size * rollout_config.data_parallel_size infer_pp = rollout_config.pipeline_model_parallel_size @@ -496,14 +498,7 @@ def init_model(self): get_device_name(), mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"] ) - # 3.2 init trainer and rollout random states - self.torch_random_states = get_torch_device().get_rng_state() - gen_dp_rank = rollout_device_mesh["dp"].get_local_rank() - 
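auto_await lets the same coroutine-returning function be driven from both synchronous and asynchronous call sites. A minimal demo, assuming the verl package from this PR is installed so the decorator can be imported from verl.utils.ray_utils:

    import asyncio

    from verl.utils.ray_utils import auto_await

    @auto_await
    async def fetch_value():
        await asyncio.sleep(0)
        return 42

    # No running event loop here, so the coroutine is executed to completion:
    print(fetch_value())  # 42

    # Inside an async context the wrapper hands back the coroutine for the caller to await:
    async def main():
        print(await fetch_value())  # 42

    asyncio.run(main())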
get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - - # 3.3 initialize rollout engine + # 3.2 initialize rollout engine rollout_cls: type[BaseRollout] = get_rollout_class(rollout_config.name, rollout_config.mode) self.rollout = rollout_cls( config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh @@ -514,6 +509,16 @@ def init_model(self): self.layered_summon = self.config.rollout.get("layered_summon", False) self.peft_merge: bool = model_config.lora.get("merge", False) + # 4. build checkpoint engine + if "actor" in self.role: + checkpoint_engine_config = omega_conf_to_dataclass(self.config.rollout.checkpoint_engine) + backend = checkpoint_engine_config.backend + bucket_size = checkpoint_engine_config.update_weights_bucket_megabytes << 20 + engine_kwargs = checkpoint_engine_config.engine_kwargs.get(backend, {}) + if torch.distributed.get_rank() == 0 and backend in ["nccl", "hccl"]: + engine_kwargs["is_master"] = True + self.checkpoint_engine = CheckpointEngineRegistry.new(backend, bucket_size=bucket_size, **engine_kwargs) + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="ref")) @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") def compute_ref_log_prob(self, data: TensorDict) -> TensorDict: @@ -542,28 +547,24 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to assert "actor" in self.role, "save_checkpoint only support actor role" self.actor.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep) - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - async def sleep(self): - """Context switch from rollout mode to trainer mode.""" - if self.config.rollout.free_cache_engine: - log_gpu_memory_usage("Before rollout offload", logger=logger) - await self.rollout.release() - log_gpu_memory_usage("After rollout offload", logger=logger) + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + async def update_weights(self): + """Update weights from trainer to rollout. - # add empty cache after each compute - aggressive_empty_cache(force_sync=True) - set_expandable_segments(True) + 1. For sync training with colocated trainer and rollout, update rollout directly from model engine. + - before update_weights: rollout should be in sleep mode. + - after update_weights: rollout should be in wake_up mode. + 2. For async training with disaggregated trainer and rollout, send_weights only by checkpoint engine. + """ + assert self.checkpoint_engine is not None - # restore random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) + # 0. send_weights only for async training with disaggregated trainer and rollout + if self.config.rollout.checkpoint_engine.backend != "naive": + per_tensor_param, _ = self.engine.get_per_tensor_param() + await self.checkpoint_engine.send_weights(per_tensor_param) + return - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - async def wake_up(self): - """Context switch trainer mode to rollout mode.""" - aggressive_empty_cache(force_sync=True) set_expandable_segments(False) - # 1. 
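The worker derives everything it passes to CheckpointEngineRegistry.new from CheckpointEngineConfig. The sketch below constructs the config and reproduces that derivation without calling the registry (which needs an initialized distributed/GPU environment); the values and the empty per-backend kwargs dict are illustrative, assuming the verl package from this PR is installed.

    from verl.workers.config import CheckpointEngineConfig

    cfg = CheckpointEngineConfig(
        backend="nccl",
        update_weights_bucket_megabytes=4096,
        engine_kwargs={"nccl": {}},  # per-backend kwargs are looked up by backend name
    )

    bucket_size = cfg.update_weights_bucket_megabytes << 20  # 4096 MB -> 4294967296 bytes
    engine_kwargs = dict(cfg.engine_kwargs.get(cfg.backend, {}))
    is_global_rank_0 = True  # stand-in for torch.distributed.get_rank() == 0
    if is_global_rank_0 and cfg.backend in ["nccl", "hccl"]:
        engine_kwargs["is_master"] = True
    print(bucket_size, engine_kwargs)
    # The worker then calls:
    # CheckpointEngineRegistry.new(cfg.backend, bucket_size=bucket_size, **engine_kwargs)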
resume weights and update weights if self.config.rollout.free_cache_engine: await self.rollout.resume(tags=["weights"]) @@ -603,6 +604,16 @@ async def wake_up(self): log_gpu_memory_usage("After resume kv_cache", logger=logger) self.base_sync_done = True - # important: need to manually set the random states of each tp to be identical. - self.torch_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.gen_random_states) + set_expandable_segments(True) + + @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False) + def execute_checkpoint_engine(self, method: str, *args, **kwargs): + """Execute checkpoint engine method. + + Args: + method (str): Checkpoint engine method name. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + """ + return getattr(self.checkpoint_engine, method)(*args, **kwargs) diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 659936891af..a5e72f84f92 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -644,13 +644,6 @@ def _build_rollout(self, trust_remote_code=False): "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect ) - # 3. init trainer and rollout random states - self.torch_random_states = get_torch_device().get_rng_state() - gen_dp_rank = rollout_device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - # 4. build rollout model log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=logger) self.rollout = get_rollout_class(rollout_config.name, rollout_config.mode)( @@ -760,28 +753,8 @@ async def rollout_mode(self): log_gpu_memory_usage("After resume kv_cache", logger=logger) self.base_sync_done = True - # important: need to manually set the random states of each tp to be identical. 
- self.torch_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.gen_random_states) - - async def trainer_mode(self): - """Context switch hybridengine to trainer mode.""" - if self.config.rollout.free_cache_engine: - log_gpu_memory_usage("Before rollout offload", logger=logger) - await self.rollout.release() - log_gpu_memory_usage("After rollout offload", logger=logger) - - self.actor_module_fsdp.train() - - # add empty cache after each compute - aggressive_empty_cache(force_sync=True) - set_expandable_segments(True) - # restore random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): from verl.workers.actor import DataParallelPPOActor @@ -2010,12 +1983,7 @@ def compute_rm_score(self, data: DataProto): # ================================= Async related workers ================================= class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - async def wake_up(self): + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + async def update_weights(self): await self.rollout_mode() return True - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - async def sleep(self): - await self.trainer_mode() - return True diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index aa7613fbc78..1323afbce58 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -531,13 +531,6 @@ def _build_rollout(self, trust_remote_code=False): "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect ) - # 3. init trainer and rollout random states - self.torch_random_states = get_torch_device().get_rng_state() - gen_dp_rank = rollout_device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - # 4. build rollout model log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=logger) self.rollout = get_rollout_class(rollout_config.name, rollout_config.mode)( @@ -741,31 +734,7 @@ async def rollout_mode(self): if self.config.rollout.free_cache_engine: await self.rollout.resume(tags=["kv_cache"]) - # important: need to manually set the random states of each tp to be identical. - self.torch_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.gen_random_states) - - async def trainer_mode(self): - """Context switch hybridengine to trainer mode.""" - if self.config.rollout.free_cache_engine: - log_gpu_memory_usage("Before rollout offload", logger=logger) - await self.rollout.release() - log_gpu_memory_usage("After rollout offload", logger=logger) - - for model in self.actor.actor_module: - model.train() - # add empty cache after each compute - aggressive_empty_cache(force_sync=True) - - # FIXME(@wuxibin): megatron+sglang failed with `expandable_segments:True` in ci, - # can't reproduce it in dev environment, temporary disable it. 
- # https://github.com/volcengine/verl/actions/runs/17382936845/job/49344264323?pr=3285 - if os.environ.get("MEGATRON_CI_DISABLE_EXPANDABLE_SEGMENTS", "0") == "0": - set_expandable_segments(True) - - # restore random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) + set_expandable_segments(True) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @GPUMemoryLogger(role="update_actor", logger=logger) @@ -1009,16 +978,11 @@ def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - async def wake_up(self): + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + async def update_weights(self): await self.rollout_mode() return True - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - async def sleep(self): - await self.trainer_mode() - return True - class CriticWorker(MegatronWorker, DistProfilerExtension): def __init__(self, config: McoreCriticConfig): diff --git a/verl/workers/rollout/__init__.py b/verl/workers/rollout/__init__.py index 5d9263c9c56..f6bd6c28b77 100644 --- a/verl/workers/rollout/__init__.py +++ b/verl/workers/rollout/__init__.py @@ -15,5 +15,6 @@ from .base import BaseRollout, get_rollout_class from .hf_rollout import HFRollout from .naive import NaiveRollout +from .replica import RolloutReplica -__all__ = ["BaseRollout", "NaiveRollout", "HFRollout", "get_rollout_class"] +__all__ = ["BaseRollout", "NaiveRollout", "HFRollout", "get_rollout_class", "RolloutReplica"] diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py index e7e800622a6..748ab436177 100644 --- a/verl/workers/rollout/replica.py +++ b/verl/workers/rollout/replica.py @@ -22,8 +22,7 @@ from pydantic import BaseModel from ray.actor import ActorHandle -from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup -from verl.trainer.ppo.ray_trainer import RayResourcePool, ResourcePoolManager +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, ResourcePoolManager from verl.utils.config import omega_conf_to_dataclass from verl.workers.config import HFModelConfig, RolloutConfig @@ -229,6 +228,18 @@ async def sleep(self): """Sleep each rollout server.""" await asyncio.gather(*[server.sleep.remote() for server in self.servers]) + async def abort_all_requests(self): + """Partial rollout: abort and save all unfinished requests in each rollout server.""" + # TODO(wuxibin) + # await asyncio.gather(*[server.abort_all_requests.remote() for server in self.servers]) + print(f"abort all requests in rollout replica {self.replica_rank}") + + async def resume_all_requests(self): + """Partial rollout: resume all unfinished requests in each rollout server.""" + # TODO(wuxibin) + # await asyncio.gather(*[server.resume_all_requests.remote() for server in self.servers]) + print(f"resume all requests in rollout replica {self.replica_rank}") + async def clear_kv_cache(self): """reset kv cache in each rollout server.""" await asyncio.gather(*[server.clear_kv_cache.remote() for server in self.servers]) diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index 1e176bafd7b..45ceb07548f 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ 
-281,9 +281,12 @@ async def launch_server(self, master_address: str = None, master_port: int = Non self.tokenizer_manager.server_status = ServerStatus.Up async def wake_up(self): + if self.node_rank != 0: + return + if self.rollout_mode == RolloutMode.HYBRID: - # Call all workers to switch between trainer mode and rollout mode. - await asyncio.gather(*[worker.wake_up.remote() for worker in self.workers]) + # In hybrid mode, rollout is wake up in `update_weights` + raise ValueError(f"wake_up not support rollout_mode {self.rollout_mode}") elif self.rollout_mode == RolloutMode.COLOCATED: # Directly call engine to wake up without sync weights. obj = ResumeMemoryOccupationReqInput(tags=["kv_cache", "weights"]) @@ -293,8 +296,12 @@ async def wake_up(self): logger.info("skip wake_up in standalone mode") async def sleep(self): + if self.node_rank != 0 or not self.config.free_cache_engine: + return + if self.rollout_mode == RolloutMode.HYBRID: - await asyncio.gather(*[worker.sleep.remote() for worker in self.workers]) + obj = ReleaseMemoryOccupationReqInput(tags=["kv_cache", "weights"]) + await self.tokenizer_manager.release_memory_occupation(obj, None) elif self.rollout_mode == RolloutMode.COLOCATED: obj = ReleaseMemoryOccupationReqInput(tags=["kv_cache", "weights"]) await self.tokenizer_manager.release_memory_occupation(obj, None) diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 24da85f20c5..2be15fc5b05 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -31,7 +31,7 @@ set_ulimit, ) from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights -from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from verl.utils.net_utils import is_valid_ipv6_address from verl.workers.config import HFModelConfig, RolloutConfig @@ -129,6 +129,21 @@ async def _init_server_adapter(self): if self._engine is not None: return + # device_mesh is needed to gather cuda ipc handle to update weights + if self.device_mesh is None: + assert torch.distributed.is_initialized(), "torch distributed must be initialized" + infer_tp = self.config.tensor_model_parallel_size * self.config.data_parallel_size + infer_pp = self.config.pipeline_model_parallel_size + infer_world_size = infer_tp * infer_pp + dp = torch.distributed.get_world_size() // infer_world_size + self.device_mesh = init_device_mesh( + "cpu", mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"] + ) + + # Only init http server adapter in tp rank 0 + if self.device_mesh["infer_tp"].get_local_rank() != 0: + return + # Lazy init http server adapter because http server is launched after hybrid engine. self.server_actor = ray.get_actor(f"sglang_server_{self.replica_rank}_{self.node_rank}") server_address, server_port = await self.server_actor.get_server_address.remote() @@ -151,14 +166,14 @@ async def resume(self, tags: list[str]): Args: tag: weights or kv_cache. 
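The CPU device mesh that _init_server_adapter now builds lazily only depends on a few sizes; a pure-arithmetic sketch of the shape it computes (example sizes are illustrative, and init_device_mesh itself still requires an initialized torch.distributed process group):

    def infer_mesh_shape(world_size: int, tensor_model_parallel_size: int,
                         data_parallel_size: int, pipeline_model_parallel_size: int):
        infer_tp = tensor_model_parallel_size * data_parallel_size
        infer_pp = pipeline_model_parallel_size
        infer_world_size = infer_tp * infer_pp
        assert world_size % infer_world_size == 0, "trainer world size must be divisible by the rollout world size"
        dp = world_size // infer_world_size
        return dp, infer_tp, infer_pp  # mesh_dim_names=["dp", "infer_tp", "infer_pp"]

    print(infer_mesh_shape(world_size=8, tensor_model_parallel_size=2,
                           data_parallel_size=1, pipeline_model_parallel_size=1))  # (4, 2, 1)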
""" + await self._init_server_adapter() if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.config.free_cache_engine: - await self._init_server_adapter() await self._engine.resume_memory_occupation(tags=tags) async def release(self): """Release weights and kv cache in GPU memory.""" + await self._init_server_adapter() if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.config.free_cache_engine: - await self._init_server_adapter() await self._engine.release_memory_occupation(tags=["kv_cache", "weights"]) async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs): @@ -174,10 +189,9 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452 - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39 """ - if self.device_mesh["infer_tp"].get_local_rank() == 0: - await self._init_server_adapter() + await self._init_server_adapter() - update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20 + update_weights_bucket_bytes = int(self.config.checkpoint_engine.update_weights_bucket_megabytes) << 20 if self.config.get("quantization", None) == "fp8": from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name @@ -190,7 +204,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None else: weights = weights - for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes): + async for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes): await sgl_update_weights( engine=self._engine, params_batch=params_batch, diff --git a/verl/workers/rollout/sglang_rollout/utils.py b/verl/workers/rollout/sglang_rollout/utils.py index f64bf63b81f..9cc66c0070c 100644 --- a/verl/workers/rollout/sglang_rollout/utils.py +++ b/verl/workers/rollout/sglang_rollout/utils.py @@ -21,6 +21,7 @@ import torch.distributed as dist from verl.utils.device import get_device_name +from verl.workers.rollout.utils import ensure_async_iterator def broadcast_pyobj( @@ -68,7 +69,7 @@ def broadcast_pyobj( return data -def get_named_tensor_buckets( +async def get_named_tensor_buckets( iterable: Iterator[tuple[str, torch.Tensor]], bucket_bytes: int ) -> Iterator[list[tuple[str, torch.Tensor]]]: """ @@ -93,15 +94,15 @@ def get_named_tensor_buckets( current_bucket = [] current_size = 0 - for name, tensor in iterable: + async for name, tensor in ensure_async_iterator(iterable): tensor_size = tensor.element_size() * tensor.numel() if current_size + tensor_size > bucket_bytes: if current_bucket: yield current_bucket - current_bucket = [(name, tensor)] + current_bucket = [(name, tensor.clone())] current_size = tensor_size else: - current_bucket.append((name, tensor)) + current_bucket.append((name, tensor.clone())) current_size += tensor_size if current_bucket: diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 791c14bd1ac..6178f779182 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -182,16 +182,19 @@ async def generate( async def wake_up(self): if self.rollout_mode == RolloutMode.HYBRID: - # Call all workers to switch between trainer mode and rollout mode. 
- await asyncio.gather(*[worker.wake_up.remote() for worker in self.workers]) - elif self.rollout_mode == RolloutMode.COLOCATED: + # In hybrid mode, rollout is wake up in `update_weights` + raise ValueError(f"wake_up not support rollout_mode {self.rollout_mode}") + if self.rollout_mode == RolloutMode.COLOCATED: await self.llm.resume(tags=ServerAdapter.get_full_tags()) elif self.rollout_mode == RolloutMode.STANDALONE: logger.info("skip wake_up in standalone mode") async def sleep(self): + if not self.config.free_cache_engine: + return + if self.rollout_mode == RolloutMode.HYBRID: - await asyncio.gather(*[worker.sleep.remote() for worker in self.workers]) + await self.llm.release(tags=ServerAdapter.get_full_tags()) elif self.rollout_mode == RolloutMode.COLOCATED: await self.llm.release(tags=ServerAdapter.get_full_tags()) elif self.rollout_mode == RolloutMode.STANDALONE: diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index f95980b24f6..459c35642ce 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -407,7 +407,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None if self.is_leader_rank: await self._init_server_adapter() - total_available_bytes = int(self.config.update_weights_bucket_megabytes) * 1024 * 1024 + total_available_bytes = int(self.config.checkpoint_engine.update_weights_bucket_megabytes) * 1024 * 1024 try: device_uuid = get_device_uuid(self.gpu_id) diff --git a/verl/workers/rollout/utils.py b/verl/workers/rollout/utils.py index 16dcfc4a5c9..246ed3896b1 100644 --- a/verl/workers/rollout/utils.py +++ b/verl/workers/rollout/utils.py @@ -56,3 +56,13 @@ async def run_unvicorn(app: FastAPI, server_args, server_address, max_retries=5) logger.info(f"HTTP server started on port {server_port}") return server_port, server_task + + +async def ensure_async_iterator(iterable): + """Convert an iterable to an async iterator.""" + if hasattr(iterable, "__aiter__"): + async for item in iterable: + yield item + else: + for item in iterable: + yield item diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 196c72bc378..56e3efd5475 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -537,25 +537,28 @@ async def generate( ) async def wake_up(self): + if self.node_rank != 0: + return + if self.rollout_mode == RolloutMode.HYBRID: - # Call all workers to switch between trainer mode and rollout mode. - await asyncio.gather(*[worker.wake_up.remote() for worker in self.workers]) + # In hybrid mode, rollout is wake up in `update_weights` + raise ValueError(f"wake_up not support rollout_mode {self.rollout_mode}") elif self.rollout_mode == RolloutMode.COLOCATED: # Directly call engine to wake up without sync weights. 
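ensure_async_iterator is what lets the weight-update paths accept either a plain generator or an async generator of (name, tensor) pairs. A minimal demo, assuming the verl package from this PR is installed:

    import asyncio

    from verl.workers.rollout.utils import ensure_async_iterator

    async def async_source():
        yield "w1"
        yield "w2"

    async def main():
        async for item in ensure_async_iterator(["w0", "w1"]):    # plain iterable
            print(item)
        async for item in ensure_async_iterator(async_source()):  # already async
            print(item)

    asyncio.run(main())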
- if self.node_rank == 0: - await self.engine.wake_up(tags=["kv_cache", "weights"]) + await self.engine.wake_up(tags=["kv_cache", "weights"]) + await self.engine.reset_prefix_cache() elif self.rollout_mode == RolloutMode.STANDALONE: logger.info("skip wake_up in standalone mode") async def sleep(self): + if self.node_rank != 0 or not self.config.free_cache_engine: + return + if self.rollout_mode == RolloutMode.HYBRID: - if self.node_rank == 0: - await self.engine.reset_prefix_cache() - await asyncio.gather(*[worker.sleep.remote() for worker in self.workers]) + # Don't use engine.sleep(level=2) here + await self.engine.collective_rpc("sleep", kwargs={"level": 2}) elif self.rollout_mode == RolloutMode.COLOCATED: - if self.node_rank == 0: - await self.engine.reset_prefix_cache() - await self.engine.sleep(level=1) + await self.engine.sleep(level=1) elif self.rollout_mode == RolloutMode.STANDALONE: logger.info("skip sleep in standalone mode") diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index ebbb6e19e48..4a767c4b684 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -45,6 +45,7 @@ from verl.utils.torch_dtypes import PrecisionType from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.base import BaseRollout +from verl.workers.rollout.utils import ensure_async_iterator from verl.workers.rollout.vllm_rollout.utils import TensorMetadata, get_device_uuid logger = logging.getLogger(__file__) @@ -152,7 +153,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None ) # build cuda ipc buffer - bucket_size_mb = self.config.update_weights_bucket_megabytes + bucket_size_mb = self.config.checkpoint_engine.update_weights_bucket_megabytes bucket_size = int(bucket_size_mb) << 20 buffer = torch.empty(bucket_size, dtype=torch.uint8, device=f"{get_device_name()}:0") handle = reduce_tensor(buffer) @@ -165,7 +166,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None offset = 0 bucket_meta: dict[str, TensorMetadata] = {} dtype = PrecisionType.to_dtype(self.config.dtype) - for name, weight in weights: + async for name, weight in ensure_async_iterator(weights): # model parameters are in fp32 full precision weight = weight.to(dtype, non_blocking=True) @@ -205,6 +206,10 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None if future is not None: await future + # reset prefix cache after updating weights + if self.rollout_rank == 0: + await self.server_handle.clear_kv_cache.remote() + if self.replica_rank == 0 and self.rollout_rank == 0: logger.info(f"update_weights done, time cost: {time.time() - start_time:.2f}s")