Skip to content

Commit a82d7e6

Browse files
committed
Update types and propagate callback.
1 parent d757552 commit a82d7e6

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

flybody/agents/ray_distributed_dmpo.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Classes for DMPO agent distributed with Ray."""
22

3-
from typing import Optional, Iterator, Callable
3+
from typing import Iterator, Callable
44
import socket
55
import dataclasses
66
import copy
@@ -61,6 +61,7 @@ class DMPOConfig:
6161
replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE
6262
print_fn: Callable = logging.info
6363
userdata: dict | None = None
64+
actor_observation_callback: Callable | None = None
6465

6566

6667
class ReplayServer():
@@ -241,10 +242,9 @@ def __init__(
241242
dmpo_config,
242243
actor_or_evaluator='actor',
243244
label=None,
244-
ray_head_node_ip: Optional[str] = None,
245-
egl_device_id_head_node: Optional[list] = None, # ['1', '2', '3']
246-
egl_device_id_worker_node: Optional[
247-
list] = None, # ['0', '1', '2', '3']
245+
ray_head_node_ip: str | None = None,
246+
egl_device_id_head_node: list | None = None,
247+
egl_device_id_worker_node: list | None = None,
248248
):
249249
"""The actor process."""
250250

@@ -258,12 +258,8 @@ def __init__(
258258
running_on_head_node = True
259259
break
260260
if running_on_head_node:
261-
# egl_device_id = egl_device_id_head_node[
262-
# actor_count % len(egl_device_id_head_node)]
263261
egl_device_id = np.random.choice(egl_device_id_head_node)
264262
else:
265-
# egl_device_id = egl_device_id_worker_node[
266-
# actor_count % len(egl_device_id_worker_node)]
267263
egl_device_id = np.random.choice(egl_device_id_worker_node)
268264
os.environ['MUJOCO_EGL_DEVICE_ID'] = str(egl_device_id)
269265

@@ -314,9 +310,11 @@ def wrapped_network_factory(action_spec):
314310
save_data = self._config.logger_save_csv_data
315311

316312
# Create the agent.
317-
actor = self._make_actor(policy_network=policy_network,
318-
adder=adder,
319-
variable_source=variable_source)
313+
actor = self._make_actor(
314+
policy_network=policy_network,
315+
adder=adder,
316+
variable_source=variable_source,
317+
observation_callback=self._config.actor_observation_callback)
320318

321319
# Create logger and counter; actors will not spam bigtable.
322320
counter = counting.Counter(parent=counter, prefix=actor_or_evaluator)
@@ -347,8 +345,9 @@ def isready(self):
347345
def _make_actor(
348346
self,
349347
policy_network: snt.Module,
350-
adder: Optional[adders.Adder] = None,
351-
variable_source: Optional[core.VariableSource] = None,
348+
adder: adders.Adder | None = None,
349+
variable_source: core.VariableSource | None = None,
350+
observation_callback: Callable | None = None,
352351
):
353352
"""Create an actor instance."""
354353
if variable_source:
@@ -369,7 +368,8 @@ def _make_actor(
369368
return DelayedFeedForwardActor(policy_network=policy_network,
370369
adder=adder,
371370
variable_client=variable_client,
372-
action_delay=None)
371+
action_delay=None,
372+
observation_callback=observation_callback)
373373

374374
def _make_adder(self, replay_client: reverb.Client) -> adders.Adder:
375375
"""Create an adder which records data generated by the actor/environment."""

0 commit comments

Comments
 (0)