3030from ray .rllib .utils .deprecation import Deprecated
3131from ray .rllib .utils .framework import get_device , try_import_torch
3232from ray .rllib .utils .metrics import (
33+ ENV_TO_MODULE_CONNECTOR ,
3334 EPISODE_DURATION_SEC_MEAN ,
3435 EPISODE_LEN_MAX ,
3536 EPISODE_LEN_MEAN ,
3637 EPISODE_LEN_MIN ,
3738 EPISODE_RETURN_MAX ,
3839 EPISODE_RETURN_MEAN ,
3940 EPISODE_RETURN_MIN ,
41+ MODULE_TO_ENV_CONNECTOR ,
4042 NUM_AGENT_STEPS_SAMPLED ,
4143 NUM_AGENT_STEPS_SAMPLED_LIFETIME ,
4244 NUM_ENV_STEPS_SAMPLED ,
4547 NUM_EPISODES_LIFETIME ,
4648 NUM_MODULE_STEPS_SAMPLED ,
4749 NUM_MODULE_STEPS_SAMPLED_LIFETIME ,
50+ RLMODULE_INFERENCE_TIMER ,
4851 SAMPLE_TIMER ,
4952 TIME_BETWEEN_SAMPLING ,
5053 WEIGHTS_SEQ_NO ,
5154)
52- from ray .rllib .utils .metrics .metrics_logger import MetricsLogger
5355from ray .rllib .utils .pre_checks .env import check_multiagent_environments
5456from ray .rllib .utils .typing import EpisodeID , ModelWeights , ResultDict , StateDict
5557from ray .tune .registry import ENV_CREATOR , _global_registry
@@ -88,8 +90,6 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
8890 self .worker_index : int = kwargs .get ("worker_index" )
8991 self .tune_trial_id : str = kwargs .get ("tune_trial_id" )
9092
91- # Set up all metrics-related structures and counters.
92- self .metrics : Optional [MetricsLogger ] = None
9393 self ._setup_metrics ()
9494
9595 # Create our callbacks object.
@@ -310,11 +310,13 @@ def _sample(
310310 self .metrics .peek (NUM_ENV_STEPS_SAMPLED_LIFETIME , default = 0 )
311311 + ts
312312 ) * (self .config .num_env_runners or 1 )
313- to_env = self .module .forward_exploration (
314- to_module , t = global_env_steps_lifetime
315- )
313+ with self .metrics .log_time (RLMODULE_INFERENCE_TIMER ):
314+ to_env = self .module .forward_exploration (
315+ to_module , t = global_env_steps_lifetime
316+ )
316317 else :
317- to_env = self .module .forward_inference (to_module )
318+ with self .metrics .log_time (RLMODULE_INFERENCE_TIMER ):
319+ to_env = self .module .forward_inference (to_module )
318320
319321 # Module-to-env connector.
320322 to_env = self ._module_to_env (
@@ -324,6 +326,7 @@ def _sample(
324326 explore = explore ,
325327 shared_data = shared_data ,
326328 metrics = self .metrics ,
329+ metrics_prefix_key = (MODULE_TO_ENV_CONNECTOR ,),
327330 )
328331 # In case all environments had been terminated `to_module` will be
329332 # empty and no actions are needed b/c we reset all environments.
@@ -453,19 +456,23 @@ def _sample(
453456 # Run the env-to-module connector pipeline for all done episodes.
454457 # Note, this is needed to postprocess last-step data, e.g. if the
455458 # user uses a connector that one-hot encodes observations.
459+ # Note, this pipeline run is not timed as the number of episodes
460+ # can differ from `num_envs_per_env_runner` and would bias time
461+ # measurements.
456462 self ._env_to_module (
457463 episodes = done_episodes_to_run_env_to_module ,
458464 explore = explore ,
459465 rl_module = self .module ,
460466 shared_data = shared_data ,
461- metrics = self . metrics ,
467+ metrics = None ,
462468 )
463469 self ._cached_to_module = self ._env_to_module (
464470 episodes = episodes ,
465471 explore = explore ,
466472 rl_module = self .module ,
467473 shared_data = shared_data ,
468474 metrics = self .metrics ,
475+ metrics_prefix_key = (ENV_TO_MODULE_CONNECTOR ,),
469476 )
470477
471478 # Numpy'ize the done episodes after running the connector pipeline. Note,
@@ -544,6 +551,7 @@ def _reset_envs(self, episodes, shared_data, explore):
544551 explore = explore ,
545552 shared_data = shared_data ,
546553 metrics = self .metrics ,
554+ metrics_key_prefix = (ENV_TO_MODULE_CONNECTOR ,),
547555 )
548556
549557 # Call `on_episode_start()` callbacks (always after reset).
@@ -871,8 +879,6 @@ def stop(self):
871879 self .env .close ()
872880
873881 def _setup_metrics (self ):
874- self .metrics = MetricsLogger ()
875-
876882 self ._done_episodes_for_metrics : List [MultiAgentEpisode ] = []
877883 self ._ongoing_episodes_for_metrics : DefaultDict [
878884 EpisodeID , List [MultiAgentEpisode ]
0 commit comments