Commit c2ed4a1 (parent: 0321478)

[ray] feat: use get_device_name() for automatic device detection instead of parameter passing

Signed-off-by: jianjunzhong <[email protected]>

File tree: 2 files changed, +25 -41 lines

tests/single_controller/test_split_resource_pool.py

Lines changed: 10 additions & 22 deletions
@@ -51,18 +51,14 @@ def test_split_resource_pool_with_split_size():
     ray.init()
     # assume we have 2 nodes, with 4 GPUs each
     global_resource_pool = RayResourcePool(process_on_nodes=[4, 4])
-    global_resource_pool.get_placement_groups(device_name=get_device_name())
+    global_resource_pool.get_placement_groups()

     # first 4 gpus for actor_1, last 4 gpus for actor_2
     actor_1_resource_pool, actor_2_resource_pool = split_resource_pool(resource_pool=global_resource_pool, split_size=4)
     actor_cls_1 = RayClassWithInitArgs(cls=Actor, worker_id=0)
     actor_cls_2 = RayClassWithInitArgs(cls=Actor, worker_id=100)
-    actor_worker_1 = RayWorkerGroup(
-        resource_pool=actor_1_resource_pool, ray_cls_with_init=actor_cls_1, device_name=get_device_name()
-    )
-    actor_worker_2 = RayWorkerGroup(
-        resource_pool=actor_2_resource_pool, ray_cls_with_init=actor_cls_2, device_name=get_device_name()
-    )
+    actor_worker_1 = RayWorkerGroup(resource_pool=actor_1_resource_pool, ray_cls_with_init=actor_cls_1)
+    actor_worker_2 = RayWorkerGroup(resource_pool=actor_2_resource_pool, ray_cls_with_init=actor_cls_2)
     assert actor_worker_1.world_size == 4
     assert actor_worker_2.world_size == 4

@@ -79,7 +75,7 @@ def test_split_resource_pool_with_split_size_list():
     ray.init()
     # assume we have 4 nodes, with 2 GPUs each
     global_resource_pool = RayResourcePool(process_on_nodes=[2, 2, 2, 2])
-    global_resource_pool.get_placement_groups(device_name=get_device_name())
+    global_resource_pool.get_placement_groups()

     # first 2 gpus for actor_1, last 6 gpus for actor_2
     actor_1_resource_pool, actor_2_resource_pool = split_resource_pool(
@@ -88,12 +84,8 @@ def test_split_resource_pool_with_split_size_list():
     )
     actor_cls_1 = RayClassWithInitArgs(cls=Actor, worker_id=0)
     actor_cls_2 = RayClassWithInitArgs(cls=Actor, worker_id=100)
-    actor_worker_1 = RayWorkerGroup(
-        resource_pool=actor_1_resource_pool, ray_cls_with_init=actor_cls_1, device_name=get_device_name()
-    )
-    actor_worker_2 = RayWorkerGroup(
-        resource_pool=actor_2_resource_pool, ray_cls_with_init=actor_cls_2, device_name=get_device_name()
-    )
+    actor_worker_1 = RayWorkerGroup(resource_pool=actor_1_resource_pool, ray_cls_with_init=actor_cls_1)
+    actor_worker_2 = RayWorkerGroup(resource_pool=actor_2_resource_pool, ray_cls_with_init=actor_cls_2)
     assert actor_worker_1.world_size == 2
     assert actor_worker_2.world_size == 6

@@ -113,7 +105,7 @@ def test_split_resource_pool_with_split_size_list_cross_nodes():
     ray.init()
     # assume we have 4 nodes, with 2 GPUs each
     global_resource_pool = RayResourcePool(process_on_nodes=[4, 4])
-    global_resource_pool.get_placement_groups(device_name=get_device_name())
+    global_resource_pool.get_placement_groups()

     # first 2 gpus for actor_1, last 6 gpus for actor_2
     actor_1_resource_pool, actor_2_resource_pool = split_resource_pool(
@@ -122,12 +114,8 @@ def test_split_resource_pool_with_split_size_list_cross_nodes():
     )
     actor_cls_1 = RayClassWithInitArgs(cls=Actor, worker_id=0)
     actor_cls_2 = RayClassWithInitArgs(cls=Actor, worker_id=100)
-    actor_worker_1 = RayWorkerGroup(
-        resource_pool=actor_1_resource_pool, ray_cls_with_init=actor_cls_1, device_name=get_device_name()
-    )
-    actor_worker_2 = RayWorkerGroup(
-        resource_pool=actor_2_resource_pool, ray_cls_with_init=actor_cls_2, device_name=get_device_name()
-    )
+    actor_worker_1 = RayWorkerGroup(resource_pool=actor_1_resource_pool, ray_cls_with_init=actor_cls_1)
+    actor_worker_2 = RayWorkerGroup(resource_pool=actor_2_resource_pool, ray_cls_with_init=actor_cls_2)

     assert actor_worker_1.world_size == 2
     assert actor_worker_2.world_size == 6
@@ -149,7 +137,7 @@ def test_split_resource_pool_with_split_twice():

     # assume we have 4 nodes, with 2 GPUs each
     global_resource_pool = RayResourcePool(process_on_nodes=[2, 2, 2, 2])
-    global_resource_pool.get_placement_groups(device_name=get_device_name())
+    global_resource_pool.get_placement_groups()

     # actors with [2, 1, 1, 1, 1, 2] (split twice)
     rp_1, rp_2, rp_3 = split_resource_pool(

verl/single_controller/ray/base.py

Lines changed: 15 additions & 19 deletions
@@ -27,7 +27,7 @@
 from verl.protocol import DataProto, _padding_size_key
 from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup
 from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch
-from verl.utils.device import get_device_name
+from verl.utils.device import get_device_name, get_resource_name
 from verl.utils.py_functional import temp_env_var

 __all__ = ["Worker"]
@@ -112,21 +112,17 @@ def __init__(
         self.detached = detached
         self.accelerator_type = accelerator_type

-    def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="cuda"):
+    def get_placement_groups(self, strategy="STRICT_PACK", name=None):
         if self.pgs is not None:
             return self.pgs

         pg_name_prefix = (
             name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:"
         )
-        # print(f"pg_name_prefix = {pg_name_prefix}")
-        if device_name == "npu":
-            device_name = "NPU"
-        elif device_name == "cuda":
-            device_name = "GPU"

         bundle = {"CPU": self.max_colocate_count}
         if self.use_gpu:
+            device_name = get_resource_name()
             bundle[device_name] = 1
             if self.accelerator_type is not None:
                 bundle[self.accelerator_type] = 1e-4
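
Note: the inline npu -> "NPU" / cuda -> "GPU" mapping removed above now sits behind get_resource_name(). Its body in verl.utils.device is not part of this diff; the following is a minimal, purely illustrative sketch consistent with the removed branch, not the committed code.

# Illustrative sketch only -- the real get_resource_name() lives in verl/utils/device.py
# and is not shown in this commit; it is assumed to mirror the mapping removed above.
from verl.utils.device import get_device_name

def get_resource_name() -> str:
    # Ray placement-group bundles count CUDA devices under the "GPU" key
    # and Ascend devices under the "NPU" key.
    return "NPU" if get_device_name() == "npu" else "GPU"
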
@@ -249,9 +245,7 @@ def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResour
     merged = type(rp1)(
         new_store, rp1.use_gpu, f"{rp1.name_prefix}_{rp2.name_prefix}", rp1.max_colocate_count, rp1.detached
     )
-    merged.pgs = rp1.get_placement_groups(device_name=get_device_name()) + rp2.get_placement_groups(
-        device_name=get_device_name()
-    )
+    merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups()

     return merged

@@ -293,7 +287,7 @@ def __call__(
        use_gpu: bool = True,
        num_gpus=1,
        sharing_with=None,
-        device_name="cuda",
+        **kwargs,
    ) -> Any:
        """Create and return a Ray actor with the configured options.

@@ -303,7 +297,7 @@ def __call__(
            use_gpu: Whether to use GPU resources
            num_gpus: Number of GPUs to allocate
            sharing_with: Actor to share resources with
-            device_name: Device for training
+            kwargs: Additional keyword arguments

        Returns:
            A Ray actor handle with the configured options
@@ -321,10 +315,12 @@ def __call__(
        }
        options.update(self._options)

-        if use_gpu and device_name == "cuda":
-            options["num_gpus"] = num_gpus
-        if use_gpu and device_name == "npu":
-            options["resources"] = {"NPU": num_gpus}
+        if use_gpu:
+            device_name = get_device_name()
+            if device_name == "cuda":
+                options["num_gpus"] = num_gpus
+            elif device_name == "npu":
+                options["resources"] = {"NPU": num_gpus}

        if len(self._additional_resource) > 1:
            for k, v in self._additional_resource.items():
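
With device_name dropped from the signature, the accelerator is now detected at call time via get_device_name(). The detection helper itself is defined in verl.utils.device and is not shown in this diff; the sketch below is a hedged guess at what it does, with the torch_npu handling in particular being an assumption.

# Hedged sketch of get_device_name(); the actual helper in verl/utils/device.py
# is not part of this commit.
import torch

def get_device_name() -> str:
    if torch.cuda.is_available():
        return "cuda"
    # torch_npu is assumed to register a torch.npu namespace when an Ascend NPU is present.
    if hasattr(torch, "npu") and torch.npu.is_available():
        return "npu"
    return "cpu"
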
@@ -380,7 +376,7 @@ def __init__(
        # if a WorkerGroup is spawned from Colocate WorkerGroup, this indicates which sub-class is binded to
        # this WorkerGroup.
        self.sub_cls_name = ""
-        self.device_name = kwargs.get("device_name", "cuda")
+        self.device_name = kwargs.get("device_name", get_device_name())
        self.profile_steps = kwargs.get("profile_steps", None)
        self.worker_nsight_options = kwargs.get("worker_nsight_options", None)
        self.customized_worker_env = kwargs.get("worker_env", {})
@@ -469,7 +465,7 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d
        strategy = "PACK"
        if bin_pack:
            strategy = "STRICT_PACK"
-        pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name)
+        pgs = resource_pool.get_placement_groups(strategy=strategy)
        world_size = resource_pool.world_size
        self._world_size = world_size
        # cia.add_kwarg("_world_size", world_size)
@@ -505,7 +501,7 @@ def _init_with_subresource_pool(self, resource_pool, ray_cls_with_init, bin_pack
        strategy = "PACK"
        if bin_pack:
            strategy = "STRICT_PACK"
-        pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name)
+        pgs = resource_pool.get_placement_groups(strategy=strategy)
        world_size = resource_pool.world_size
        self._world_size = world_size

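Taken together, callers no longer thread device_name through resource pools or worker groups. Below is a hedged end-to-end sketch of the simplified API after this commit; the Actor worker is a placeholder standing in for the test module's Actor class, and the 2-node x 4-accelerator layout is an assumption borrowed from the tests.

import ray

from verl.single_controller.base import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup


@ray.remote
class Actor(Worker):  # placeholder worker, analogous to the Actor used in the tests
    def __init__(self, worker_id: int):
        super().__init__()
        self.worker_id = worker_id


ray.init()

# No device_name anywhere: get_device_name()/get_resource_name() pick the accelerator
# and the matching Ray resource key inside the library.
pool = RayResourcePool(process_on_nodes=[4, 4])
pool.get_placement_groups()

actor_cls = RayClassWithInitArgs(cls=Actor, worker_id=0)
workers = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=actor_cls)
assert workers.world_size == 8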