@@ -51,18 +51,14 @@ def test_split_resource_pool_with_split_size():
5151 ray .init ()
5252 # assume we have 2 nodes, with 4 GPUs each
5353 global_resource_pool = RayResourcePool (process_on_nodes = [4 , 4 ])
54- global_resource_pool .get_placement_groups (device_name = get_device_name () )
54+ global_resource_pool .get_placement_groups ()
5555
5656 # first 4 gpus for actor_1, last 4 gpus for actor_2
5757 actor_1_resource_pool , actor_2_resource_pool = split_resource_pool (resource_pool = global_resource_pool , split_size = 4 )
5858 actor_cls_1 = RayClassWithInitArgs (cls = Actor , worker_id = 0 )
5959 actor_cls_2 = RayClassWithInitArgs (cls = Actor , worker_id = 100 )
60- actor_worker_1 = RayWorkerGroup (
61- resource_pool = actor_1_resource_pool , ray_cls_with_init = actor_cls_1 , device_name = get_device_name ()
62- )
63- actor_worker_2 = RayWorkerGroup (
64- resource_pool = actor_2_resource_pool , ray_cls_with_init = actor_cls_2 , device_name = get_device_name ()
65- )
60+ actor_worker_1 = RayWorkerGroup (resource_pool = actor_1_resource_pool , ray_cls_with_init = actor_cls_1 )
61+ actor_worker_2 = RayWorkerGroup (resource_pool = actor_2_resource_pool , ray_cls_with_init = actor_cls_2 )
6662 assert actor_worker_1 .world_size == 4
6763 assert actor_worker_2 .world_size == 4
6864
@@ -79,7 +75,7 @@ def test_split_resource_pool_with_split_size_list():
7975 ray .init ()
8076 # assume we have 4 nodes, with 2 GPUs each
8177 global_resource_pool = RayResourcePool (process_on_nodes = [2 , 2 , 2 , 2 ])
82- global_resource_pool .get_placement_groups (device_name = get_device_name () )
78+ global_resource_pool .get_placement_groups ()
8379
8480 # first 2 gpus for actor_1, last 6 gpus for actor_2
8581 actor_1_resource_pool , actor_2_resource_pool = split_resource_pool (
@@ -88,12 +84,8 @@ def test_split_resource_pool_with_split_size_list():
8884 )
8985 actor_cls_1 = RayClassWithInitArgs (cls = Actor , worker_id = 0 )
9086 actor_cls_2 = RayClassWithInitArgs (cls = Actor , worker_id = 100 )
91- actor_worker_1 = RayWorkerGroup (
92- resource_pool = actor_1_resource_pool , ray_cls_with_init = actor_cls_1 , device_name = get_device_name ()
93- )
94- actor_worker_2 = RayWorkerGroup (
95- resource_pool = actor_2_resource_pool , ray_cls_with_init = actor_cls_2 , device_name = get_device_name ()
96- )
87+ actor_worker_1 = RayWorkerGroup (resource_pool = actor_1_resource_pool , ray_cls_with_init = actor_cls_1 )
88+ actor_worker_2 = RayWorkerGroup (resource_pool = actor_2_resource_pool , ray_cls_with_init = actor_cls_2 )
9789 assert actor_worker_1 .world_size == 2
9890 assert actor_worker_2 .world_size == 6
9991
@@ -113,7 +105,7 @@ def test_split_resource_pool_with_split_size_list_cross_nodes():
113105 ray .init ()
114106 # assume we have 4 nodes, with 2 GPUs each
115107 global_resource_pool = RayResourcePool (process_on_nodes = [4 , 4 ])
116- global_resource_pool .get_placement_groups (device_name = get_device_name () )
108+ global_resource_pool .get_placement_groups ()
117109
118110 # first 2 gpus for actor_1, last 6 gpus for actor_2
119111 actor_1_resource_pool , actor_2_resource_pool = split_resource_pool (
@@ -122,12 +114,8 @@ def test_split_resource_pool_with_split_size_list_cross_nodes():
122114 )
123115 actor_cls_1 = RayClassWithInitArgs (cls = Actor , worker_id = 0 )
124116 actor_cls_2 = RayClassWithInitArgs (cls = Actor , worker_id = 100 )
125- actor_worker_1 = RayWorkerGroup (
126- resource_pool = actor_1_resource_pool , ray_cls_with_init = actor_cls_1 , device_name = get_device_name ()
127- )
128- actor_worker_2 = RayWorkerGroup (
129- resource_pool = actor_2_resource_pool , ray_cls_with_init = actor_cls_2 , device_name = get_device_name ()
130- )
117+ actor_worker_1 = RayWorkerGroup (resource_pool = actor_1_resource_pool , ray_cls_with_init = actor_cls_1 )
118+ actor_worker_2 = RayWorkerGroup (resource_pool = actor_2_resource_pool , ray_cls_with_init = actor_cls_2 )
131119
132120 assert actor_worker_1 .world_size == 2
133121 assert actor_worker_2 .world_size == 6
@@ -149,7 +137,7 @@ def test_split_resource_pool_with_split_twice():
149137
150138 # assume we have 4 nodes, with 2 GPUs each
151139 global_resource_pool = RayResourcePool (process_on_nodes = [2 , 2 , 2 , 2 ])
152- global_resource_pool .get_placement_groups (device_name = get_device_name () )
140+ global_resource_pool .get_placement_groups ()
153141
154142 # actors with [2, 1, 1, 1, 1, 2] (split twice)
155143 rp_1 , rp_2 , rp_3 = split_resource_pool (
0 commit comments