From fac12375dd6287d0ed4b6f0c218f415d5050a06b Mon Sep 17 00:00:00 2001 From: Tucker Date: Tue, 11 Apr 2023 09:40:41 -0400 Subject: [PATCH] Add missing docstrings. --- smarts/ray/sensors/ray_sensor_resolver.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/smarts/ray/sensors/ray_sensor_resolver.py b/smarts/ray/sensors/ray_sensor_resolver.py index bdd7133fc2..80863e7614 100644 --- a/smarts/ray/sensors/ray_sensor_resolver.py +++ b/smarts/ray/sensors/ray_sensor_resolver.py @@ -71,7 +71,7 @@ def get_ray_worker_actors(self, count: int): if len(self._current_workers) != count: # we need to cache because using options(name) is extremely slow self._current_workers = [ - ProcessWorker.options( + RayProcessWorker.options( name=f"sensor_worker_{i}", get_if_exists=True ).remote() for i in range(count) @@ -109,7 +109,7 @@ def observe( # Start remote tasks agent_ids_for_grouping = list(agent_ids) agent_groups = [ - agent_ids_for_grouping[i::len_workers] for i in range(len_workers) + frozenset(agent_ids_for_grouping[i::len_workers]) for i in range(len_workers) ] for i, agent_group in enumerate(agent_groups): if not agent_group: @@ -162,14 +162,29 @@ def step(self, sim_frame, sensor_states): @ray.remote -class ProcessWorker: +class RayProcessWorker: + """A `ray` based process worker for parallel operation on sensors.""" def __init__(self) -> None: self._simulation_local_constants: Optional[SimulationLocalConstants] = None def update_local_constants(self, sim_local_constants): + """Updates the process worker. + + Args: + sim_local_constants (SimulationLocalConstants | None): The current simulation reset state. + """ self._simulation_local_constants = loads(sim_local_constants) def do_work(self, remote_sim_frame, agent_ids): + """Run the sensors against the current simulation state. + + Args: + remote_sim_frame (SimulationFrame): The current simulation state. + agent_ids (set[str]): The agent ids to operate on. + + Returns: + tuple[dict, dict, dict]: The updated sensor states: (observations, dones, updated_sensors) + """ sim_frame = loads(remote_sim_frame) return Sensors.observe_serializable_sensor_batch( sim_frame, self._simulation_local_constants, agent_ids