2525register_env ("basic_multiagent" , lambda _ : BasicMultiAgent (2 ))
2626
2727
28+ def _get_mapper ():
29+ # Note(Artur): This was originally part of the unittest.TestCase.setUpClass
30+ # method but caused trouble when serializing the config because we ended up
31+ # serializing `self`, which is an instance of unittest.TestCase.
32+
33+ # When dealing with two policies in these tests, simply alternate between the 2
34+ # policies to make sure we have data for inference for both policies for each
35+ # step.
36+ class AlternatePolicyMapper :
37+ def __init__ (self ):
38+ self .policies = ["one" , "two" ]
39+ self .next = 0
40+
41+ def map (self ):
42+ p = self .policies [self .next ]
43+ self .next = 1 - self .next
44+ return p
45+
46+ return AlternatePolicyMapper ()
47+
48+
2849class TestEnvRunnerV2 (unittest .TestCase ):
2950 @classmethod
3051 def setUpClass (cls ):
3152 ray .init ()
3253
33- # When dealing with two policies in these tests, simply alternate between the 2
34- # policies to make sure we have data for inference for both policies for each
35- # step.
36- class AlternatePolicyMapper :
37- def __init__ (self ):
38- self .policies = ["one" , "two" ]
39- self .next = 0
40-
41- def map (self ):
42- p = self .policies [self .next ]
43- self .next = 1 - self .next
44- return p
45-
46- cls .mapper = AlternatePolicyMapper ()
47-
4854 @classmethod
4955 def tearDownClass (cls ):
5056 ray .shutdown ()
@@ -215,6 +221,8 @@ def __init__(self, *args, **kwargs):
215221 self .view_requirements ["rewards" ].used_for_compute_actions = False
216222 self .view_requirements ["terminateds" ].used_for_compute_actions = False
217223
224+ mapper = _get_mapper ()
225+
218226 config = (
219227 PPOConfig ()
220228 .api_stack (
@@ -240,7 +248,7 @@ def __init__(self, *args, **kwargs):
240248 policy_class = RandomPolicyTwo ,
241249 ),
242250 },
243- policy_mapping_fn = lambda * args , ** kwargs : self . mapper .map (),
251+ policy_mapping_fn = lambda * args , ** kwargs : mapper .map (),
244252 policies_to_train = ["one" ],
245253 count_steps_by = "agent_steps" ,
246254 )
@@ -316,6 +324,7 @@ def on_create_policy(self, *, policy_id, policy) -> None:
316324 _ = rollout_worker .sample ()
317325
318326 def test_start_episode (self ):
327+ mapper = _get_mapper ()
319328 config = (
320329 PPOConfig ()
321330 .api_stack (
@@ -341,7 +350,7 @@ def test_start_episode(self):
341350 policy_class = RandomPolicy ,
342351 ),
343352 },
344- policy_mapping_fn = lambda * args , ** kwargs : self . mapper .map (),
353+ policy_mapping_fn = lambda * args , ** kwargs : mapper .map (),
345354 policies_to_train = ["one" ],
346355 count_steps_by = "agent_steps" ,
347356 )
@@ -373,6 +382,7 @@ def test_start_episode(self):
373382 self .assertEqual (env_runner ._active_episodes [0 ].total_agent_steps , 2 )
374383
375384 def test_env_runner_output (self ):
385+ mapper = _get_mapper ()
376386 # Test if we can produce RolloutMetrics just by stepping
377387 config = (
378388 PPOConfig ()
@@ -399,7 +409,7 @@ def test_env_runner_output(self):
399409 policy_class = RandomPolicy ,
400410 ),
401411 },
402- policy_mapping_fn = lambda * args , ** kwargs : self . mapper .map (),
412+ policy_mapping_fn = lambda * args , ** kwargs : mapper .map (),
403413 policies_to_train = ["one" ],
404414 count_steps_by = "agent_steps" ,
405415 )
@@ -434,6 +444,7 @@ def on_episode_end(
434444 # We should see an error episode.
435445 assert isinstance (episode , Exception )
436446
447+ mapper = _get_mapper ()
437448 # Test if we can produce RolloutMetrics just by stepping
438449 config = (
439450 PPOConfig ()
@@ -460,7 +471,7 @@ def on_episode_end(
460471 policy_class = RandomPolicy ,
461472 ),
462473 },
463- policy_mapping_fn = lambda * args , ** kwargs : self . mapper .map (),
474+ policy_mapping_fn = lambda * args , ** kwargs : mapper .map (),
464475 policies_to_train = ["one" ],
465476 count_steps_by = "agent_steps" ,
466477 )
0 commit comments