
Commit 3b94e5f

[RLlib; Offline RL] BC performance improvements and adjustments to new Learner.update logic. (ray-project#51425)
1 parent 8773682 commit 3b94e5f

File tree

13 files changed: +395 −127 lines changed

rllib/algorithms/bc/bc.py

Lines changed: 8 additions & 0 deletions
@@ -97,6 +97,14 @@ def build_learner_connector(
         pipeline.remove("AddOneTsToEpisodesAndTruncate")
         pipeline.remove("GeneralAdvantageEstimation")
 
+        # In case we run multiple updates per RLlib training step in the `Learner` or
+        # when training on GPU conversion to tensors is managed in batch prefetching.
+        if self.num_gpus_per_learner > 0 or (
+            self.dataset_num_iters_per_learner
+            and self.dataset_num_iters_per_learner > 1
+        ):
+            pipeline.remove("NumpyToTensor")
+
         return pipeline
 
     @override(MARWILConfig)
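For context, the branch added above fires not only for GPU learners but also when a learner runs more than one update per RLlib training step. A minimal, hedged config sketch of a setup that would take this path (new API stack; the dataset path and exact argument spellings are assumptions, not taken from this commit):

```python
from ray.rllib.algorithms.bc import BCConfig

config = (
    BCConfig()
    .environment("CartPole-v1")
    .offline_data(
        input_="/tmp/cartpole-offline",        # hypothetical local dataset path
        dataset_num_iters_per_learner=5,       # > 1: several updates per training step
    )
    .learners(num_gpus_per_learner=1)          # or: train the learner on a GPU
)
# Under either setting, the learner connector built from this config drops
# `NumpyToTensor`, so batches stay NumPy until the prefetching iterator
# converts and moves them.
```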

rllib/algorithms/cql/cql.py

Lines changed: 6 additions & 2 deletions
@@ -212,8 +212,12 @@ def build_learner_connector(
             AddNextObservationsFromEpisodesToTrainBatch(),
         )
 
-        # If training on GPU, do not convert batches to tensors.
-        if self.num_gpus_per_learner > 0:
+        # In case we run multiple updates per RLlib training step in the `Learner` or
+        # when training on GPU conversion to tensors is managed in batch prefetching.
+        if self.num_gpus_per_learner > 0 or (
+            self.dataset_num_iters_per_learner
+            and self.dataset_num_iters_per_learner > 1
+        ):
             pipeline.remove("NumpyToTensor")
 
         return pipeline

rllib/algorithms/marwil/marwil.py

Lines changed: 6 additions & 1 deletion
@@ -377,7 +377,12 @@ def build_learner_connector(
 
         # If training on GPU, convert batches to `numpy` arrays to load them
         # on GPU in the `Learner`.
-        if self.num_gpus_per_learner > 0:
+        # In case we run multiple updates per RLlib training step in the `Learner` or
+        # when training on GPU conversion to tensors is managed in batch prefetching.
+        if self.num_gpus_per_learner > 0 or (
+            self.dataset_num_iters_per_learner
+            and self.dataset_num_iters_per_learner > 1
+        ):
             pipeline.insert_after(GeneralAdvantageEstimation, TensorToNumpy())
 
         return pipeline

rllib/algorithms/marwil/tests/test_marwil.py

Lines changed: 2 additions & 1 deletion
@@ -10,6 +10,7 @@
 from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY
 from ray.rllib.env import INPUT_ENV_SPACES
 from ray.rllib.offline.offline_prelearner import OfflinePreLearner
+from ray.rllib.utils import unflatten_dict
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.test_utils import check
 
@@ -172,7 +173,7 @@ def test_marwil_loss_function(self):
         )
         # Note, for `ray.data`'s pipeline everything has to be a dictionary
         # therefore the batch is embedded into another dictionary.
-        batch = offline_prelearner(batch)["batch"][0]
+        batch = unflatten_dict(offline_prelearner(batch))
         if Columns.LOSS_MASK in batch[DEFAULT_MODULE_ID]:
             loss_mask = (
                 batch[DEFAULT_MODULE_ID][Columns.LOSS_MASK].detach().cpu().numpy()

rllib/core/learner/learner.py

Lines changed: 52 additions & 30 deletions
@@ -19,7 +19,6 @@
 )
 
 import ray
-from ray.data.iterator import DataIterator
 from ray.rllib.connectors.learner.learner_connector_pipeline import (
     LearnerConnectorPipeline,
 )
@@ -37,6 +36,7 @@
     MultiRLModuleSpec,
 )
 from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
+from ray.rllib.utils import unflatten_dict
 from ray.rllib.policy.policy import PolicySpec
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.utils.annotations import (
@@ -269,7 +269,7 @@ def __init__(
 
         # In case of offline learning and multiple learners, each learner receives a
         # repeatable iterator that iterates over a split of the streamed data.
-        self.iterator: DataIterator = None
+        self.iterator: MiniBatchRayDataIterator = None
 
         # TODO (sven): Do we really need this API? It seems like LearnerGroup constructs
         # all Learner workers and then immediately builds them any ways? Unless there is
@@ -727,7 +727,13 @@ def get_parameters(self, module: RLModule) -> Sequence[Param]:
         """
 
     @abc.abstractmethod
-    def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
+    def _convert_batch_type(
+        self,
+        batch: MultiAgentBatch,
+        to_device: bool = False,
+        pin_memory: bool = False,
+        use_stream: bool = False,
+    ) -> MultiAgentBatch:
         """Converts the elements of a MultiAgentBatch to Tensors on the correct device.
 
         Args:
@@ -1041,33 +1047,36 @@ def update(
                 "Learner.update(data_iterators=..) requires `num_iters` kwarg!"
             )
 
+        def _collate_fn(_batch: Dict[str, numpy.ndarray]) -> MultiAgentBatch:
+            _batch = unflatten_dict(_batch)
+            _batch = MultiAgentBatch(
+                {
+                    module_id: SampleBatch(module_data)
+                    for module_id, module_data in _batch.items()
+                },
+                env_steps=sum(
+                    len(next(iter(module_data.values())))
+                    for module_data in _batch.values()
+                ),
+            )
+            _batch = self._convert_batch_type(_batch, to_device=False)
+            return self._set_slicing_by_batch_id(_batch, value=True)
+
+        def _finalize_fn(batch: MultiAgentBatch) -> MultiAgentBatch:
+            return self._convert_batch_type(batch, to_device=True, use_stream=True)
+
         if not self.iterator:
-            self.iterator = training_data.data_iterators[0]
-
-            def _finalize_fn(_batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
-                # Note, the incoming batch is a dictionary with a numpy array
-                # holding the `MultiAgentBatch`.
-                _batch = self._convert_batch_type(_batch["batch"][0])
-                return {"batch": self._set_slicing_by_batch_id(_batch, value=True)}
-
-            batch_iter = MiniBatchRayDataIterator(
-                iterator=self.iterator,
-                finalize_fn=_finalize_fn,
-                num_iters=num_iters,
-                **kwargs,
-            )
-            # Record the number of batches pulled from the dataset.
-            self.metrics.log_value(
-                (ALL_MODULES, DATASET_NUM_ITERS_TRAINED),
-                num_iters,
-                reduce="sum",
-                clear_on_reduce=True,
-            )
-            self.metrics.log_value(
-                (ALL_MODULES, DATASET_NUM_ITERS_TRAINED_LIFETIME),
-                num_iters,
-                reduce="sum",
-            )
+            # This iterator holds a `ray.data.DataIterator` and manages it state.
+            self.iterator = MiniBatchRayDataIterator(
+                iterator=training_data.data_iterators[0],
+                collate_fn=_collate_fn,
+                finalize_fn=_finalize_fn,
+                minibatch_size=minibatch_size,
+                num_iters=num_iters,
+                **kwargs,
+            )
+
+            batch_iter = self.iterator
         else:
             batch = self._make_batch_if_necessary(training_data=training_data)
             assert batch is not None
@@ -1104,7 +1113,7 @@ def _finalize_fn(_batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
         )
 
         # Perform the actual looping through the minibatches or the given data iterator.
-        for tensor_minibatch in batch_iter:
+        for iteration, tensor_minibatch in enumerate(batch_iter):
             # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs
             # found in this batch. If not, throw an error.
             unknown_module_ids = set(tensor_minibatch.policy_batches.keys()) - set(
@@ -1133,6 +1142,19 @@ def _finalize_fn(_batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
 
             self._set_slicing_by_batch_id(tensor_minibatch, value=False)
 
+            if self.iterator:
+                # Record the number of batches pulled from the dataset.
+                self.metrics.log_value(
+                    (ALL_MODULES, DATASET_NUM_ITERS_TRAINED),
+                    iteration + 1,
+                    reduce="sum",
+                    clear_on_reduce=True,
+                )
+                self.metrics.log_value(
+                    (ALL_MODULES, DATASET_NUM_ITERS_TRAINED_LIFETIME),
+                    iteration + 1,
+                    reduce="sum",
+                )
             # Log all individual RLModules' loss terms and its registered optimizers'
             # current learning rates.
             # Note: We do this only once for the last of the minibatch updates, b/c the
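The refactored `update` path splits batch preparation into a `_collate_fn` (rebuild the `MultiAgentBatch` and enable slicing, still on CPU) and a `_finalize_fn` (convert and move to the device, optionally on a copy stream), and hands both to the prefetching iterator. A generic, hedged sketch of that collate/finalize pattern follows; this is not the `MiniBatchRayDataIterator` implementation, just the idea behind splitting the two callbacks:

```python
from typing import Any, Callable, Iterable, Iterator


def prefetching_iter(
    raw_batches: Iterable[Any],
    collate_fn: Callable[[Any], Any],   # CPU-side restructuring of the raw batch
    finalize_fn: Callable[[Any], Any],  # device placement, e.g. async H2D copy
) -> Iterator[Any]:
    """One-step lookahead: prepare batch N+1 before batch N is handed out.

    If `finalize_fn` only launches an asynchronous host-to-device copy,
    that copy can overlap with the caller's compute on the previous batch.
    """
    pending = None
    for raw in raw_batches:
        nxt = finalize_fn(collate_fn(raw))
        if pending is not None:
            yield pending
        pending = nxt
    if pending is not None:
        yield pending
```

Note also that `DATASET_NUM_ITERS_TRAINED` is now logged inside the loop with `iteration + 1`, so the metric reflects the number of batches actually pulled from the dataset rather than the requested `num_iters`.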

rllib/core/learner/torch/torch_learner.py

Lines changed: 13 additions & 2 deletions
@@ -366,8 +366,19 @@ def get_parameters(self, module: RLModule) -> Sequence[Param]:
         return list(module.parameters())
 
     @override(Learner)
-    def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
-        batch = convert_to_torch_tensor(batch.policy_batches, device=self._device)
+    def _convert_batch_type(
+        self,
+        batch: MultiAgentBatch,
+        to_device: bool = True,
+        pin_memory: bool = False,
+        use_stream: bool = False,
+    ) -> MultiAgentBatch:
+        batch = convert_to_torch_tensor(
+            batch.policy_batches,
+            device=self._device if to_device else None,
+            pin_memory=pin_memory,
+            use_stream=use_stream,
+        )
         # TODO (sven): This computation of `env_steps` is not accurate!
         length = max(len(b) for b in batch.values())
         batch = MultiAgentBatch(batch, env_steps=length)
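The new `pin_memory` and `use_stream` flags are forwarded to `convert_to_torch_tensor`. The underlying PyTorch mechanism is the standard one: pin the host tensor, then issue the host-to-device copy with `non_blocking=True` on a dedicated stream so it can overlap with compute on the default stream. A hedged, generic sketch of that mechanism (plain PyTorch, not RLlib's utility):

```python
import torch


def async_to_device(
    host_tensor: torch.Tensor,
    device: torch.device,
    copy_stream: torch.cuda.Stream,
) -> torch.Tensor:
    # Pinned (page-locked) host memory is what makes the copy truly asynchronous.
    pinned = host_tensor.pin_memory()
    with torch.cuda.stream(copy_stream):
        # Launch the copy on the side stream; it returns without waiting.
        return pinned.to(device, non_blocking=True)


if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    batch = torch.randn(1024, 32)
    on_gpu = async_to_device(batch, torch.device("cuda:0"), stream)
    # The compute stream must wait for the copy before reading `on_gpu`.
    torch.cuda.current_stream().wait_stream(stream)
```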

rllib/offline/offline_data.py

Lines changed: 23 additions & 5 deletions
@@ -1,14 +1,19 @@
 import logging
 from pathlib import Path
 import pyarrow.fs
+import numpy as np
 import ray
 import time
 import types
 
+from typing import Dict
+
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.core import COMPONENT_RL_MODULE
 from ray.rllib.env import INPUT_ENV_SPACES
 from ray.rllib.offline.offline_prelearner import OfflinePreLearner
+from ray.rllib.utils import unflatten_dict
+from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.utils import force_list
 from ray.rllib.utils.annotations import (
     OverrideToImplementCustomLogic,
@@ -223,13 +228,26 @@ def sample(
             self.batch_iterators = self.data.iterator()
         # Otherwise, the user wants batches returned.
         else:
+            # Define a collate (last-mile) transformation that maps batches
+            # to RLlib's `MultiAgentBatch`.
+            def _collate_fn(_batch: Dict[str, np.ndarray]) -> MultiAgentBatch:
+                _batch = unflatten_dict(_batch)
+                return MultiAgentBatch(
+                    {
+                        module_id: SampleBatch(module_data)
+                        for module_id, module_data in _batch.items()
+                    },
+                    env_steps=sum(
+                        len(next(iter(module_data.values())))
+                        for module_data in _batch.values()
+                    ),
+                )
+
             # If no iterator should be returned, or if we want to return a single
             # batch iterator, we instantiate the batch iterator once, here.
             self.batch_iterators = self.data.iter_batches(
-                # This is important. The batch size is now 1, because the data
-                # is already run through the `OfflinePreLearner` and a single
-                # instance is a single `MultiAgentBatch` of size `num_samples`.
-                batch_size=1,
+                batch_size=num_samples,
+                _collate_fn=_collate_fn,
                 **self.iter_batches_kwargs,
            )
             self.batch_iterators = iter(self.batch_iterators)
@@ -240,7 +258,7 @@ def sample(
         else:
             # Return a single batch from the iterator.
             try:
-                return next(self.batch_iterators)["batch"][0]
+                return next(self.batch_iterators)
             except StopIteration:
                 # If the batch iterator is exhausted, reinitiate a new one.
                 logger.debug(
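With `OfflinePreLearner` now emitting flat column dictionaries, the collate step is what turns a batch of NumPy columns back into a `MultiAgentBatch`, including the `env_steps` count taken from the length of the first column per module. A toy illustration of that mapping on an already un-flattened dict (module and column names are made up for the example):

```python
import numpy as np
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

# Per-module column data, as it would look after un-flattening.
nested = {
    "default_policy": {
        "obs": np.zeros((32, 4), dtype=np.float32),
        "actions": np.zeros((32,), dtype=np.int64),
    },
}

ma_batch = MultiAgentBatch(
    {mid: SampleBatch(cols) for mid, cols in nested.items()},
    # env_steps: length of the first column per module, summed over modules,
    # mirroring the `_collate_fn` in the diff above.
    env_steps=sum(len(next(iter(cols.values()))) for cols in nested.values()),
)
assert ma_batch.env_steps() == 32
```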

rllib/offline/offline_prelearner.py

Lines changed: 8 additions & 22 deletions
@@ -8,7 +8,7 @@
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec, MultiRLModule
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
-from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
+from ray.rllib.utils import flatten_dict
 from ray.rllib.utils.annotations import (
     OverrideToImplementCustomLogic,
     OverrideToImplementCustomLogic_CallToSuperRecommended,
@@ -137,7 +137,7 @@ def __init__(
         )
 
     @OverrideToImplementCustomLogic
-    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]:
+    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Prepares plain data batches for training with `Learner`'s.
 
         Args:
@@ -212,7 +212,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
             self._is_multi_agent,
             batch,
             schema=SCHEMA | self.config.input_read_schema,
-            to_numpy=True,
+            to_numpy=False,
             input_compress_columns=self.config.input_compress_columns,
             observation_space=self.observation_space,
             action_space=self.action_space,
@@ -255,28 +255,14 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
             # LearnerConnector pipeline.
             metrics=None,
         )
-        # Convert to `MultiAgentBatch`.
-        batch = MultiAgentBatch(
-            {
-                module_id: SampleBatch(module_data)
-                for module_id, module_data in batch.items()
-            },
-            # TODO (simon): This can be run once for the batch and the
-            # metrics, but we run it twice: here and later in the learner.
-            env_steps=sum(e.env_steps() for e in episodes),
-        )
         # Remove all data from modules that should not be trained. We do
-        # not want to pass around more data than necessaty.
-        for module_id in list(batch.policy_batches.keys()):
+        # not want to pass around more data than necessary.
+        for module_id in batch:
             if not self._should_module_be_updated(module_id, batch):
-                del batch.policy_batches[module_id]
-
-        # TODO (simon): Log steps trained for metrics (how?). At best in learner
-        # and not here. But we could precompute metrics here and pass it to the learner
-        # for logging. Like this we do not have to pass around episode lists.
+                del batch[module_id]
 
-        # TODO (simon): episodes are only needed for logging here.
-        return {"batch": [batch]}
+        # Flatten the dictionary to increase serialization performance.
+        return flatten_dict(batch)
 
     @property
     def default_prelearner_buffer_class(self) -> ReplayBuffer:
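The prelearner now returns a flat column dictionary because flat dicts of NumPy arrays serialize and move through `ray.data` more cheaply than nested ones; the learner's collate step un-flattens them again. A self-contained sketch of what such a flatten/unflatten pair does, using hypothetical stand-in helpers rather than the `ray.rllib.utils` versions (whose exact key separator is not shown in this commit):

```python
from typing import Any, Dict

SEP = "/"  # assumed separator for the illustration


def flatten_dict(nested: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Turn {"module": {"obs": arr}} into {"module/obs": arr}."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        full_key = f"{prefix}{SEP}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_dict(value, prefix=full_key))
        else:
            flat[full_key] = value
    return flat


def unflatten_dict(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Invert `flatten_dict`: {"module/obs": arr} -> {"module": {"obs": arr}}."""
    nested: Dict[str, Any] = {}
    for full_key, value in flat.items():
        *path, leaf = full_key.split(SEP)
        node = nested
        for part in path:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


batch = {"default_policy": {"obs": [1, 2, 3], "actions": [0, 1, 0]}}
assert unflatten_dict(flatten_dict(batch)) == batch
```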
