1
1
"""Classes for DMPO agent distributed with Ray."""
2
2
3
- from typing import Optional , Iterator , Callable
3
+ from typing import Iterator , Callable
4
4
import socket
5
5
import dataclasses
6
6
import copy
@@ -61,6 +61,7 @@ class DMPOConfig:
61
61
replay_table_name : str = reverb_adders .DEFAULT_PRIORITY_TABLE
62
62
print_fn : Callable = logging .info
63
63
userdata : dict | None = None
64
+ actor_observation_callback : Callable | None = None
64
65
65
66
66
67
class ReplayServer ():
@@ -241,10 +242,9 @@ def __init__(
241
242
dmpo_config ,
242
243
actor_or_evaluator = 'actor' ,
243
244
label = None ,
244
- ray_head_node_ip : Optional [str ] = None ,
245
- egl_device_id_head_node : Optional [list ] = None , # ['1', '2', '3']
246
- egl_device_id_worker_node : Optional [
247
- list ] = None , # ['0', '1', '2', '3']
245
+ ray_head_node_ip : str | None = None ,
246
+ egl_device_id_head_node : list | None = None ,
247
+ egl_device_id_worker_node : list | None = None ,
248
248
):
249
249
"""The actor process."""
250
250
@@ -258,12 +258,8 @@ def __init__(
258
258
running_on_head_node = True
259
259
break
260
260
if running_on_head_node :
261
- # egl_device_id = egl_device_id_head_node[
262
- # actor_count % len(egl_device_id_head_node)]
263
261
egl_device_id = np .random .choice (egl_device_id_head_node )
264
262
else :
265
- # egl_device_id = egl_device_id_worker_node[
266
- # actor_count % len(egl_device_id_worker_node)]
267
263
egl_device_id = np .random .choice (egl_device_id_worker_node )
268
264
os .environ ['MUJOCO_EGL_DEVICE_ID' ] = str (egl_device_id )
269
265
@@ -314,9 +310,11 @@ def wrapped_network_factory(action_spec):
314
310
save_data = self ._config .logger_save_csv_data
315
311
316
312
# Create the agent.
317
- actor = self ._make_actor (policy_network = policy_network ,
318
- adder = adder ,
319
- variable_source = variable_source )
313
+ actor = self ._make_actor (
314
+ policy_network = policy_network ,
315
+ adder = adder ,
316
+ variable_source = variable_source ,
317
+ observation_callback = self ._config .actor_observation_callback )
320
318
321
319
# Create logger and counter; actors will not spam bigtable.
322
320
counter = counting .Counter (parent = counter , prefix = actor_or_evaluator )
@@ -347,8 +345,9 @@ def isready(self):
347
345
def _make_actor (
348
346
self ,
349
347
policy_network : snt .Module ,
350
- adder : Optional [adders .Adder ] = None ,
351
- variable_source : Optional [core .VariableSource ] = None ,
348
+ adder : adders .Adder | None = None ,
349
+ variable_source : core .VariableSource | None = None ,
350
+ observation_callback : Callable | None = None ,
352
351
):
353
352
"""Create an actor instance."""
354
353
if variable_source :
@@ -369,7 +368,8 @@ def _make_actor(
369
368
return DelayedFeedForwardActor (policy_network = policy_network ,
370
369
adder = adder ,
371
370
variable_client = variable_client ,
372
- action_delay = None )
371
+ action_delay = None ,
372
+ observation_callback = observation_callback )
373
373
374
374
def _make_adder (self , replay_client : reverb .Client ) -> adders .Adder :
375
375
"""Create an adder which records data generated by the actor/environment."""
0 commit comments