Commit e94a60b

[RLLib] Pass large AlgorithmConfig by reference to RolloutWorker (ray-project#50688)
Signed-off-by: Jiajun Yao <[email protected]>
1 parent 245ddfc commit e94a60b

File tree

2 files changed: +35 -20 lines changed

rllib/env/env_runner_group.py

Lines changed: 5 additions & 1 deletion
@@ -130,6 +130,7 @@ def __init__(
         self._env_creator = env_creator
         self._policy_class = default_policy_class
         self._remote_config = config
+        self._remote_config_obj_ref = ray.put(self._remote_config)
         self._remote_args = {
             "num_cpus": self._remote_config.num_cpus_per_env_runner,
             "num_gpus": self._remote_config.num_gpus_per_env_runner,
@@ -665,7 +666,10 @@ def add_workers(self, num_workers: int, validate: bool = False) -> None:
                     validate_env=None,
                     worker_index=old_num_workers + i + 1,
                     num_workers=old_num_workers + num_workers,
-                    config=self._remote_config,
+                    # self._remote_config can be large, and it's best practice
+                    # to pass it by reference instead of by value
+                    # (https://docs.ray.io/en/latest/ray-core/patterns/pass-large-arg-by-value.html).
+                    config=self._remote_config_obj_ref,
                 )
                 for i in range(num_workers)
             ]
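
The change above is Ray's documented fix for the "passing large arguments by value" anti-pattern: instead of re-serializing the full AlgorithmConfig into every worker's creation task, the config is placed in the object store once with ray.put() and only the small ObjectRef is shipped; Ray resolves top-level ObjectRef arguments back into their values before the actor's constructor runs. A minimal, self-contained sketch of the same pattern follows; the Worker actor and the dict standing in for the config are hypothetical, not RLlib code.

import ray

ray.init()


@ray.remote
class Worker:
    def __init__(self, config):
        # Ray de-references top-level ObjectRef arguments before calling
        # __init__, so `config` arrives here as the plain dict.
        self.config = config

    def get(self, key):
        return self.config[key]


# Hypothetical stand-in for a large AlgorithmConfig.
large_config = {"train_batch_size": 4000, "payload": [0.0] * 1_000_000}

# Put the large object in the object store once...
config_ref = ray.put(large_config)

# ...and hand the small ObjectRef to every worker instead of serializing
# the full object into each creation task.
workers = [Worker.remote(config_ref) for _ in range(4)]
print(ray.get(workers[0].get.remote("train_batch_size")))  # -> 4000

With N workers this serializes the config once rather than N times, and workers on the same node fetch it from the local object store.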

rllib/evaluation/tests/test_env_runner_v2.py

Lines changed: 30 additions & 19 deletions
@@ -25,26 +25,32 @@
 register_env("basic_multiagent", lambda _: BasicMultiAgent(2))


+def _get_mapper():
+    # Note(Artur): This was originally part of the unittest.TestCase.setUpClass
+    # method but caused trouble when serializing the config because we ended up
+    # serializing `self`, which is an instance of unittest.TestCase.
+
+    # When dealing with two policies in these tests, simply alternate between the 2
+    # policies to make sure we have data for inference for both policies for each
+    # step.
+    class AlternatePolicyMapper:
+        def __init__(self):
+            self.policies = ["one", "two"]
+            self.next = 0
+
+        def map(self):
+            p = self.policies[self.next]
+            self.next = 1 - self.next
+            return p
+
+    return AlternatePolicyMapper()
+
+
 class TestEnvRunnerV2(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         ray.init()

-        # When dealing with two policies in these tests, simply alternate between the 2
-        # policies to make sure we have data for inference for both policies for each
-        # step.
-        class AlternatePolicyMapper:
-            def __init__(self):
-                self.policies = ["one", "two"]
-                self.next = 0
-
-            def map(self):
-                p = self.policies[self.next]
-                self.next = 1 - self.next
-                return p
-
-        cls.mapper = AlternatePolicyMapper()
-
     @classmethod
     def tearDownClass(cls):
         ray.shutdown()
@@ -215,6 +221,8 @@ def __init__(self, *args, **kwargs):
                 self.view_requirements["rewards"].used_for_compute_actions = False
                 self.view_requirements["terminateds"].used_for_compute_actions = False

+        mapper = _get_mapper()
+
         config = (
             PPOConfig()
             .api_stack(
@@ -240,7 +248,7 @@ def __init__(self, *args, **kwargs):
                     policy_class=RandomPolicyTwo,
                 ),
             },
-            policy_mapping_fn=lambda *args, **kwargs: self.mapper.map(),
+            policy_mapping_fn=lambda *args, **kwargs: mapper.map(),
             policies_to_train=["one"],
             count_steps_by="agent_steps",
         )
@@ -316,6 +324,7 @@ def on_create_policy(self, *, policy_id, policy) -> None:
         _ = rollout_worker.sample()

     def test_start_episode(self):
+        mapper = _get_mapper()
         config = (
             PPOConfig()
             .api_stack(
@@ -341,7 +350,7 @@ def test_start_episode(self):
                     policy_class=RandomPolicy,
                 ),
             },
-            policy_mapping_fn=lambda *args, **kwargs: self.mapper.map(),
+            policy_mapping_fn=lambda *args, **kwargs: mapper.map(),
             policies_to_train=["one"],
             count_steps_by="agent_steps",
         )
@@ -373,6 +382,7 @@ def test_start_episode(self):
         self.assertEqual(env_runner._active_episodes[0].total_agent_steps, 2)

     def test_env_runner_output(self):
+        mapper = _get_mapper()
         # Test if we can produce RolloutMetrics just by stepping
         config = (
             PPOConfig()
@@ -399,7 +409,7 @@ def test_env_runner_output(self):
                     policy_class=RandomPolicy,
                 ),
             },
-            policy_mapping_fn=lambda *args, **kwargs: self.mapper.map(),
+            policy_mapping_fn=lambda *args, **kwargs: mapper.map(),
             policies_to_train=["one"],
             count_steps_by="agent_steps",
         )
@@ -434,6 +444,7 @@ def on_episode_end(
             # We should see an error episode.
             assert isinstance(episode, Exception)

+        mapper = _get_mapper()
         # Test if we can produce RolloutMetrics just by stepping
         config = (
             PPOConfig()
@@ -460,7 +471,7 @@ def on_episode_end(
                     policy_class=RandomPolicy,
                 ),
             },
-            policy_mapping_fn=lambda *args, **kwargs: self.mapper.map(),
+            policy_mapping_fn=lambda *args, **kwargs: mapper.map(),
             policies_to_train=["one"],
             count_steps_by="agent_steps",
         )
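
The test refactor above exists because a policy_mapping_fn written as lambda *args, **kwargs: self.mapper.map() closes over `self`, so serializing the config for remote workers drags the entire unittest.TestCase instance into the pickle; the module-level _get_mapper() factory lets the lambda close over only the small mapper object. Below is a minimal sketch of the pitfall, assuming cloudpickle (which Ray bundles for serialization); the Holder and Mapper classes are hypothetical stand-ins, not the test code.

import cloudpickle


class Mapper:
    def map(self):
        return "one"


class Holder:
    def __init__(self):
        self.mapper = Mapper()
        # Stand-in for heavy (or outright unpicklable) test-fixture state.
        self.heavy = [0.0] * 1_000_000

    def make_fn(self):
        # Pitfall: the lambda closes over `self`, so pickling the lambda
        # also pickles the whole Holder, including `heavy`.
        return lambda *args, **kwargs: self.mapper.map()


def make_fn():
    mapper = Mapper()
    # Fix: the lambda closes over only the small, picklable `mapper`.
    return lambda *args, **kwargs: mapper.map()


bad = Holder().make_fn()
good = make_fn()
print(len(cloudpickle.dumps(bad)))   # megabytes: drags in Holder.heavy
print(len(cloudpickle.dumps(good)))  # small: just Mapper and the closure

If Holder held genuinely unpicklable state, as unittest.TestCase instances often do, cloudpickle.dumps(bad) would fail outright; that is the failure mode the Note(Artur) comment describes.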
