Commit 06da16e

[RLlib] Add timers to env step, forward pass, and complete connector pipeline runs. (ray-project#51160)
1 parent 1ed623d commit 06da16e

9 files changed: +89 / -29 lines changed

rllib/connectors/connector_pipeline_v2.py

Lines changed: 17 additions & 2 deletions
@@ -7,8 +7,9 @@
 from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.checkpoints import Checkpointable
-from ray.rllib.utils.metrics import TIMERS, CONNECTOR_TIMERS
+from ray.rllib.utils.metrics import TIMERS, CONNECTOR_PIPELINE_TIMER, CONNECTOR_TIMERS
 from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+from ray.rllib.utils.metrics.utils import to_snake_case
 from ray.rllib.utils.typing import EpisodeType, StateDict
 from ray.util.annotations import PublicAPI

@@ -95,6 +96,13 @@ def __call__(
                 piece in the pipeline.
         """
         shared_data = shared_data if shared_data is not None else {}
+        full_stats = None
+        if metrics:
+            full_stats = metrics.log_time(
+                kwargs.get("metrics_prefix_key", ()) + (CONNECTOR_PIPELINE_TIMER,)
+            )
+            full_stats.__enter__()
+
         # Loop through connector pieces and call each one with the output of the
         # previous one. Thereby, time each connector piece's call.
         for connector in self.connectors:
@@ -104,7 +112,11 @@ def __call__(
             if metrics:
                 stats = metrics.log_time(
                     kwargs.get("metrics_prefix_key", ())
-                    + (TIMERS, CONNECTOR_TIMERS, connector.__class__.__name__)
+                    + (
+                        TIMERS,
+                        CONNECTOR_TIMERS,
+                        to_snake_case(connector.__class__.__name__),
+                    )
                 )
                 stats.__enter__()

@@ -131,6 +143,9 @@ def __call__(
                     f"the `data` arg passed in (either altered or unchanged)."
                 )

+        if metrics:
+            full_stats.__exit__(None, None, None)
+
         return batch

     def remove(self, name_or_class: Union[str, Type]):
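
The pipeline timer wraps one full `__call__` run, while each piece keeps its own timer under a snake-cased key. The following is a minimal sketch, not part of the commit, of where those keys land, assuming a pipeline was invoked with `metrics=metrics` and `metrics_prefix_key=(ENV_TO_MODULE_CONNECTOR,)` and contains a piece such as `FlattenObservations` (example name only):

# Sketch only: key layout inferred from the diff above, not an official API guarantee.
from ray.rllib.utils.metrics import (
    CONNECTOR_PIPELINE_TIMER,
    CONNECTOR_TIMERS,
    ENV_TO_MODULE_CONNECTOR,
    TIMERS,
)
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

metrics = MetricsLogger()
# ... a pipeline call with `metrics=metrics` and
# `metrics_prefix_key=(ENV_TO_MODULE_CONNECTOR,)` would have run here ...

# Total wall-clock time of one full pipeline pass.
pipeline_time = metrics.peek(
    (ENV_TO_MODULE_CONNECTOR, CONNECTOR_PIPELINE_TIMER), default=None
)
# Per-piece time, keyed by the snake-cased class name.
piece_time = metrics.peek(
    (ENV_TO_MODULE_CONNECTOR, TIMERS, CONNECTOR_TIMERS, "flatten_observations"),
    default=None,
)
print(pipeline_time, piece_time)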

rllib/connectors/learner/learner_connector_pipeline.py

Lines changed: 5 additions & 1 deletion
@@ -4,6 +4,7 @@
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.metrics import (
     ALL_MODULES,
+    LEARNER_CONNECTOR,
     LEARNER_CONNECTOR_SUM_EPISODES_LENGTH_IN,
     LEARNER_CONNECTOR_SUM_EPISODES_LENGTH_OUT,
 )
@@ -42,7 +43,10 @@ def __call__(
             shared_data=shared_data if shared_data is not None else {},
             explore=explore,
             metrics=metrics,
-            metrics_prefix_key=(ALL_MODULES,),
+            metrics_prefix_key=(
+                ALL_MODULES,
+                LEARNER_CONNECTOR,
+            ),
             **kwargs,
         )

rllib/core/learner/tests/test_learner_group.py

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@
 from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
-from ray.rllib.utils.metrics import ALL_MODULES, TIMERS
+from ray.rllib.utils.metrics import ALL_MODULES, LEARNER_CONNECTOR
 from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
 from ray.rllib.utils.test_utils import check
 from ray.util.timer import _Timer
@@ -474,8 +474,8 @@ def test_save_to_path_and_restore_from_path(self):
             results_2nd_update_with_break,
             results_2nd_update_without_break,
         ):
-            r1[ALL_MODULES].pop(TIMERS)
-            r2[ALL_MODULES].pop(TIMERS)
+            r1[ALL_MODULES].pop(LEARNER_CONNECTOR)
+            r2[ALL_MODULES].pop(LEARNER_CONNECTOR)
             check(
                 MetricsLogger.peek_results(results_2nd_update_with_break),
                 MetricsLogger.peek_results(results_2nd_update_without_break),

rllib/env/env_runner.py

Lines changed: 12 additions & 5 deletions
@@ -9,6 +9,8 @@
 from ray.rllib.core import COMPONENT_RL_MODULE
 from ray.rllib.utils.actor_manager import FaultAwareApply
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.metrics import ENV_RESET_TIMER, ENV_STEP_TIMER
+from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
 from ray.rllib.utils.torch_utils import convert_to_torch_tensor
 from ray.rllib.utils.typing import StateDict, TensorType
 from ray.util.annotations import PublicAPI, DeveloperAPI
@@ -51,8 +53,10 @@ def __init__(self, *, config: "AlgorithmConfig", **kwargs):
             config: The AlgorithmConfig to use to setup this EnvRunner.
             **kwargs: Forward compatibility kwargs.
         """
-        self.config = config.copy(copy_frozen=False)
+        self.config: AlgorithmConfig = config.copy(copy_frozen=False)
         self.env = None
+        # Create a MetricsLogger object for logging custom stats.
+        self.metrics: MetricsLogger = MetricsLogger()

         super().__init__(**kwargs)

@@ -160,9 +164,11 @@ def _try_env_reset(self):
         """Tries resetting the env and - if an error orrurs - handles it gracefully."""
         # Try to reset.
         try:
-            obs, infos = self.env.reset(
-                seed=self.config.seed and self.config.seed + (self.worker_index or 0),
-            )
+            with self.metrics.log_time(ENV_RESET_TIMER):
+                obs, infos = self.env.reset(
+                    seed=self.config.seed
+                    and self.config.seed + (self.worker_index or 0),
+                )
             # Everything ok -> return.
             return obs, infos
         # Error.
@@ -183,7 +189,8 @@ def _try_env_reset(self):
     def _try_env_step(self, actions):
         """Tries stepping the env and - if an error orrurs - handles it gracefully."""
         try:
-            results = self.env.step(actions)
+            with self.metrics.log_time(ENV_STEP_TIMER):
+                results = self.env.step(actions)
             return results
         except Exception as e:
             if self.config.restart_failed_sub_environments:
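
Both env hunks use the same pattern: `MetricsLogger.log_time()` wraps the timed call as a context manager, and the value can later be read back with `peek()`. A minimal standalone sketch of that pattern (not part of the commit; the `time.sleep()` merely stands in for the env call):

import time

from ray.rllib.utils.metrics import ENV_STEP_TIMER
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

metrics = MetricsLogger()
# Time one (fake) env step, mirroring what `_try_env_step()` does above.
with metrics.log_time(ENV_STEP_TIMER):
    time.sleep(0.01)  # stands in for `self.env.step(actions)`
# Read the logged time (in seconds) back out.
print(metrics.peek(ENV_STEP_TIMER))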

rllib/env/multi_agent_env_runner.py

Lines changed: 16 additions & 10 deletions
@@ -30,13 +30,15 @@
 from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.framework import get_device, try_import_torch
 from ray.rllib.utils.metrics import (
+    ENV_TO_MODULE_CONNECTOR,
     EPISODE_DURATION_SEC_MEAN,
     EPISODE_LEN_MAX,
     EPISODE_LEN_MEAN,
     EPISODE_LEN_MIN,
     EPISODE_RETURN_MAX,
     EPISODE_RETURN_MEAN,
     EPISODE_RETURN_MIN,
+    MODULE_TO_ENV_CONNECTOR,
     NUM_AGENT_STEPS_SAMPLED,
     NUM_AGENT_STEPS_SAMPLED_LIFETIME,
     NUM_ENV_STEPS_SAMPLED,
@@ -45,11 +47,11 @@
     NUM_EPISODES_LIFETIME,
     NUM_MODULE_STEPS_SAMPLED,
     NUM_MODULE_STEPS_SAMPLED_LIFETIME,
+    RLMODULE_INFERENCE_TIMER,
     SAMPLE_TIMER,
     TIME_BETWEEN_SAMPLING,
     WEIGHTS_SEQ_NO,
 )
-from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
 from ray.rllib.utils.pre_checks.env import check_multiagent_environments
 from ray.rllib.utils.typing import EpisodeID, ModelWeights, ResultDict, StateDict
 from ray.tune.registry import ENV_CREATOR, _global_registry
@@ -88,8 +90,6 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
         self.worker_index: int = kwargs.get("worker_index")
         self.tune_trial_id: str = kwargs.get("tune_trial_id")

-        # Set up all metrics-related structures and counters.
-        self.metrics: Optional[MetricsLogger] = None
         self._setup_metrics()

         # Create our callbacks object.
@@ -310,11 +310,13 @@ def _sample(
                 self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
                 + ts
             ) * (self.config.num_env_runners or 1)
-            to_env = self.module.forward_exploration(
-                to_module, t=global_env_steps_lifetime
-            )
+            with self.metrics.log_time(RLMODULE_INFERENCE_TIMER):
+                to_env = self.module.forward_exploration(
+                    to_module, t=global_env_steps_lifetime
+                )
         else:
-            to_env = self.module.forward_inference(to_module)
+            with self.metrics.log_time(RLMODULE_INFERENCE_TIMER):
+                to_env = self.module.forward_inference(to_module)

         # Module-to-env connector.
         to_env = self._module_to_env(
@@ -324,6 +326,7 @@ def _sample(
             explore=explore,
             shared_data=shared_data,
             metrics=self.metrics,
+            metrics_prefix_key=(MODULE_TO_ENV_CONNECTOR,),
         )
         # In case all environments had been terminated `to_module` will be
         # empty and no actions are needed b/c we reset all environemnts.
@@ -453,19 +456,23 @@ def _sample(
             # Run the env-to-module connector pipeline for all done episodes.
             # Note, this is needed to postprocess last-step data, e.g. if the
             # user uses a connector that one-hot encodes observations.
+            # Note, this pipeline run is not timed as the number of episodes
+            # can differ from `num_envs_per_env_runner` and would bias time
+            # measurements.
             self._env_to_module(
                 episodes=done_episodes_to_run_env_to_module,
                 explore=explore,
                 rl_module=self.module,
                 shared_data=shared_data,
-                metrics=self.metrics,
+                metrics=None,
             )
             self._cached_to_module = self._env_to_module(
                 episodes=episodes,
                 explore=explore,
                 rl_module=self.module,
                 shared_data=shared_data,
                 metrics=self.metrics,
+                metrics_prefix_key=(ENV_TO_MODULE_CONNECTOR,),
             )

             # Numpy'ize the done episodes after running the connector pipeline. Note,
@@ -544,6 +551,7 @@ def _reset_envs(self, episodes, shared_data, explore):
             explore=explore,
             shared_data=shared_data,
             metrics=self.metrics,
+            metrics_key_prefix=(ENV_TO_MODULE_CONNECTOR,),
         )

         # Call `on_episode_start()` callbacks (always after reset).
@@ -871,8 +879,6 @@ def stop(self):
         self.env.close()

     def _setup_metrics(self):
-        self.metrics = MetricsLogger()
-
         self._done_episodes_for_metrics: List[MultiAgentEpisode] = []
         self._ongoing_episodes_for_metrics: DefaultDict[
             EpisodeID, List[MultiAgentEpisode]

rllib/env/single_agent_env_runner.py

Lines changed: 12 additions & 8 deletions
@@ -32,13 +32,15 @@
 from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.framework import get_device
 from ray.rllib.utils.metrics import (
+    ENV_TO_MODULE_CONNECTOR,
     EPISODE_DURATION_SEC_MEAN,
     EPISODE_LEN_MAX,
     EPISODE_LEN_MEAN,
     EPISODE_LEN_MIN,
     EPISODE_RETURN_MAX,
     EPISODE_RETURN_MEAN,
     EPISODE_RETURN_MIN,
+    MODULE_TO_ENV_CONNECTOR,
     NUM_AGENT_STEPS_SAMPLED,
     NUM_AGENT_STEPS_SAMPLED_LIFETIME,
     NUM_ENV_STEPS_SAMPLED,
@@ -47,11 +49,11 @@
     NUM_EPISODES_LIFETIME,
     NUM_MODULE_STEPS_SAMPLED,
     NUM_MODULE_STEPS_SAMPLED_LIFETIME,
+    RLMODULE_INFERENCE_TIMER,
     SAMPLE_TIMER,
     TIME_BETWEEN_SAMPLING,
     WEIGHTS_SEQ_NO,
 )
-from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
 from ray.rllib.utils.spaces.space_utils import unbatch
 from ray.rllib.utils.typing import EpisodeID, ResultDict, StateDict
 from ray.tune.registry import ENV_CREATOR, _global_registry
@@ -80,9 +82,6 @@ def __init__(self, *, config: AlgorithmConfig, **kwargs):
         self.num_workers: int = kwargs.get("num_workers", self.config.num_env_runners)
         self.tune_trial_id: str = kwargs.get("tune_trial_id")

-        # Create a MetricsLogger object for logging custom stats.
-        self.metrics = MetricsLogger()
-
         # Create our callbacks object.
         self._callbacks: List[RLlibCallback] = [
             cls() for cls in force_list(self.config.callbacks_class)
@@ -296,11 +295,13 @@ def _sample(
                 self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
                 + ts
             ) * (self.config.num_env_runners or 1)
-            to_env = self.module.forward_exploration(
-                to_module, t=global_env_steps_lifetime
-            )
+            with self.metrics.log_time(RLMODULE_INFERENCE_TIMER):
+                to_env = self.module.forward_exploration(
+                    to_module, t=global_env_steps_lifetime
+                )
         else:
-            to_env = self.module.forward_inference(to_module)
+            with self.metrics.log_time(RLMODULE_INFERENCE_TIMER):
+                to_env = self.module.forward_inference(to_module)

         # Module-to-env connector.
         to_env = self._module_to_env(
@@ -310,6 +311,7 @@ def _sample(
             explore=explore,
             shared_data=shared_data,
             metrics=self.metrics,
+            metrics_prefix_key=(MODULE_TO_ENV_CONNECTOR,),
         )

         # Extract the (vectorized) actions (to be sent to the env) from the
@@ -370,6 +372,7 @@ def _sample(
             rl_module=self.module,
             shared_data=shared_data,
             metrics=self.metrics,
+            metrics_prefix_key=(ENV_TO_MODULE_CONNECTOR,),
         )

         for env_index in range(self.num_envs):
@@ -738,6 +741,7 @@ def _reset_envs(self, episodes, shared_data, explore):
             explore=explore,
             shared_data=shared_data,
             metrics=self.metrics,
+            metrics_prefix_key=(ENV_TO_MODULE_CONNECTOR,),
         )

         # Call `on_episode_start()` callbacks (always after reset).

rllib/utils/metrics/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -160,14 +160,21 @@
 GRAD_WAIT_TIMER = "grad_wait"
 SAMPLE_TIMER = "sample"  # @OldAPIStack
 ENV_RUNNER_SAMPLING_TIMER = "env_runner_sampling_timer"
+ENV_RESET_TIMER = "env_reset_timer"
+ENV_STEP_TIMER = "env_step_timer"
+ENV_TO_MODULE_CONNECTOR = "env_to_module_connector"
+RLMODULE_INFERENCE_TIMER = "rlmodule_inference_timer"
+MODULE_TO_ENV_CONNECTOR = "module_to_env_connector"
 OFFLINE_SAMPLING_TIMER = "offline_sampling_timer"
 REPLAY_BUFFER_ADD_DATA_TIMER = "replay_buffer_add_data_timer"
 REPLAY_BUFFER_SAMPLE_TIMER = "replay_buffer_sampling_timer"
 REPLAY_BUFFER_UPDATE_PRIOS_TIMER = "replay_buffer_update_prios_timer"
+LEARNER_CONNECTOR = "learner_connector"
 LEARNER_UPDATE_TIMER = "learner_update_timer"
 LEARN_ON_BATCH_TIMER = "learn"  # @OldAPIStack
 LOAD_BATCH_TIMER = "load"
 TARGET_NET_UPDATE_TIMER = "target_net_update"
+CONNECTOR_PIPELINE_TIMER = "connector_pipeline_timer"
 CONNECTOR_TIMERS = "connectors"

 # Learner.
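
These new constants are key names rather than metrics themselves. As a hedged sketch (not established by this diff), assuming an EnvRunner's metrics are reported under the `ENV_RUNNER_RESULTS` section of an Algorithm's result dict, the env timers could be pulled out roughly like this:

# Sketch only: the nesting under ENV_RUNNER_RESULTS is an assumption, not part of the commit.
from typing import Any, Dict, Optional

from ray.rllib.utils.metrics import (
    ENV_RESET_TIMER,
    ENV_RUNNER_RESULTS,
    ENV_STEP_TIMER,
)


def env_timers(results: Dict[str, Any]) -> Dict[str, Optional[float]]:
    """Returns the (assumed) env reset/step timers from one training iteration."""
    env_runner_results = results.get(ENV_RUNNER_RESULTS, {})
    return {
        ENV_RESET_TIMER: env_runner_results.get(ENV_RESET_TIMER),
        ENV_STEP_TIMER: env_runner_results.get(ENV_STEP_TIMER),
    }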

rllib/utils/metrics/stats.py

Lines changed: 2 additions & 0 deletions
@@ -9,11 +9,13 @@
 from ray.rllib.utils import force_list
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.numpy import convert_to_numpy
+from ray.util.annotations import DeveloperAPI

 _, tf, _ = try_import_tf()
 torch, _ = try_import_torch()


+@DeveloperAPI
 class Stats:
     """A container class holding a number of values and executing reductions over them.

rllib/utils/metrics/utils.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+import re
+
+
+def to_snake_case(class_name: str) -> str:
+    """Converts class name to snake case.
+
+    This is used to unify metrics names when using class names within.
+    Args:
+        class_name: A string defining a class name (usually in camel
+            case).
+
+    Returns:
+        The class name in snake case.
+    """
+    return re.sub(r"(?<!^)(?=[A-Z])", "_", class_name).lower()
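
A quick, self-contained sanity check (illustrative only) of what `to_snake_case()` produces for typical connector class names:

from ray.rllib.utils.metrics.utils import to_snake_case

assert to_snake_case("FlattenObservations") == "flatten_observations"
assert to_snake_case("AddObservationsFromEpisodesToBatch") == (
    "add_observations_from_episodes_to_batch"
)
# Runs of capitals are split letter by letter by the regex:
assert to_snake_case("RLModule") == "r_l_module"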
