|
19 | 19 | )
20 | 20 |
|
21 | 21 | import ray |
22 | | -from ray.data.iterator import DataIterator |
23 | 22 | from ray.rllib.connectors.learner.learner_connector_pipeline import ( |
24 | 23 | LearnerConnectorPipeline, |
25 | 24 | ) |
|
37 | 36 | MultiRLModuleSpec, |
38 | 37 | ) |
39 | 38 | from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec |
| 39 | +from ray.rllib.utils import unflatten_dict |
40 | 40 | from ray.rllib.policy.policy import PolicySpec |
41 | 41 | from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch |
42 | 42 | from ray.rllib.utils.annotations import ( |
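
A note on the new `unflatten_dict` import: the `_collate_fn` added further down receives Ray Data batches whose flat column names encode the per-module nesting, and `unflatten_dict` rebuilds the nested mapping. A minimal sketch of the assumed behavior (the real helper lives in `ray.rllib.utils`; the "/" delimiter here is an assumption):

def unflatten(flat: dict, sep: str = "/") -> dict:
    # Rebuild {"module_1/obs": arr} into {"module_1": {"obs": arr}}.
    nested: dict = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested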
@@ -269,7 +269,7 @@ def __init__( |
269 | 269 |
|
270 | 270 | # In case of offline learning and multiple learners, each learner receives a |
271 | 271 | # repeatable iterator that iterates over a split of the streamed data. |
272 | | - self.iterator: DataIterator = None |
| 272 | + self.iterator: MiniBatchRayDataIterator = None |
273 | 273 |
|
274 | 274 | # TODO (sven): Do we really need this API? It seems like LearnerGroup constructs |
275 | 275 | # all Learner workers and then immediately builds them anyway? Unless there is
@@ -727,7 +727,13 @@ def get_parameters(self, module: RLModule) -> Sequence[Param]: |
727 | 727 | """ |
728 | 728 |
|
729 | 729 | @abc.abstractmethod |
730 | | - def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch: |
| 730 | + def _convert_batch_type( |
| 731 | + self, |
| 732 | + batch: MultiAgentBatch, |
| 733 | + to_device: bool = False, |
| 734 | + pin_memory: bool = False, |
| 735 | + use_stream: bool = False, |
| 736 | + ) -> MultiAgentBatch: |
731 | 737 | """Converts the elements of a MultiAgentBatch to Tensors on the correct device. |
732 | 738 |
|
733 | 739 | Args: |
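
For a torch-based Learner, the widened signature suggests an override along these lines; a minimal sketch assuming torch and a CUDA device (names and structure are illustrative, not RLlib's actual `TorchLearner` code):

import numpy as np
import torch

def convert_batch_sketch(
    batch: dict,  # stand-in for MultiAgentBatch, illustrative only
    device: torch.device,
    to_device: bool = False,
    pin_memory: bool = False,
    use_stream: bool = False,
) -> dict:
    out = {}
    for key, value in batch.items():
        tensor = torch.from_numpy(np.asarray(value))
        if pin_memory and torch.cuda.is_available():
            # Pinned host memory allows asynchronous host-to-device copies.
            tensor = tensor.pin_memory()
        if to_device:
            if use_stream and torch.cuda.is_available():
                # Copy on a side stream so the transfer can overlap compute;
                # a real implementation would synchronize before first use.
                copy_stream = torch.cuda.Stream()
                with torch.cuda.stream(copy_stream):
                    tensor = tensor.to(device, non_blocking=True)
            else:
                tensor = tensor.to(device)
        out[key] = tensor
    return out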
@@ -1041,33 +1047,36 @@ def update( |
1041 | 1047 | "Learner.update(data_iterators=..) requires `num_iters` kwarg!" |
1042 | 1048 | ) |
1043 | 1049 |
|
| 1050 | + def _collate_fn(_batch: Dict[str, numpy.ndarray]) -> MultiAgentBatch: |
| 1051 | + _batch = unflatten_dict(_batch) |
| 1052 | + _batch = MultiAgentBatch( |
| 1053 | + { |
| 1054 | + module_id: SampleBatch(module_data) |
| 1055 | + for module_id, module_data in _batch.items() |
| 1056 | + }, |
| 1057 | + env_steps=sum( |
| 1058 | + len(next(iter(module_data.values()))) |
| 1059 | + for module_data in _batch.values() |
| 1060 | + ), |
| 1061 | + ) |
| 1062 | + _batch = self._convert_batch_type(_batch, to_device=False) |
| 1063 | + return self._set_slicing_by_batch_id(_batch, value=True) |
| 1064 | + |
| 1065 | + def _finalize_fn(batch: MultiAgentBatch) -> MultiAgentBatch: |
| 1066 | + return self._convert_batch_type(batch, to_device=True, use_stream=True) |
| 1067 | + |
1044 | 1068 | if not self.iterator: |
1045 | | - self.iterator = training_data.data_iterators[0] |
1046 | | - |
1047 | | - def _finalize_fn(_batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]: |
1048 | | - # Note, the incoming batch is a dictionary with a numpy array |
1049 | | - # holding the `MultiAgentBatch`. |
1050 | | - _batch = self._convert_batch_type(_batch["batch"][0]) |
1051 | | - return {"batch": self._set_slicing_by_batch_id(_batch, value=True)} |
1052 | | - |
1053 | | - batch_iter = MiniBatchRayDataIterator( |
1054 | | - iterator=self.iterator, |
1055 | | - finalize_fn=_finalize_fn, |
1056 | | - num_iters=num_iters, |
1057 | | - **kwargs, |
1058 | | - ) |
1059 | | - # Record the number of batches pulled from the dataset. |
1060 | | - self.metrics.log_value( |
1061 | | - (ALL_MODULES, DATASET_NUM_ITERS_TRAINED), |
1062 | | - num_iters, |
1063 | | - reduce="sum", |
1064 | | - clear_on_reduce=True, |
1065 | | - ) |
1066 | | - self.metrics.log_value( |
1067 | | - (ALL_MODULES, DATASET_NUM_ITERS_TRAINED_LIFETIME), |
1068 | | - num_iters, |
1069 | | - reduce="sum", |
1070 | | - ) |
| 1069 | + # This iterator holds a `ray.data.DataIterator` and manages its state.
| 1070 | + self.iterator = MiniBatchRayDataIterator( |
| 1071 | + iterator=training_data.data_iterators[0], |
| 1072 | + collate_fn=_collate_fn, |
| 1073 | + finalize_fn=_finalize_fn, |
| 1074 | + minibatch_size=minibatch_size, |
| 1075 | + num_iters=num_iters, |
| 1076 | + **kwargs, |
| 1077 | + ) |
| 1078 | + |
| 1079 | + batch_iter = self.iterator |
1071 | 1080 | else: |
1072 | 1081 | batch = self._make_batch_if_necessary(training_data=training_data) |
1073 | 1082 | assert batch is not None |
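
Because `self.iterator` now persists across `update()` calls, each call resumes the data stream where the previous one stopped. A rough stand-in for the contract this code assumes from `MiniBatchRayDataIterator` (hand-rolled sketch, `minibatch_size` omitted for brevity; not RLlib's actual class):

from typing import Any, Callable, Iterable, Iterator

class MiniBatchIteratorSketch:
    """Yields up to num_iters collated/finalized minibatches per for-loop,
    resuming from the previous position on the next loop."""

    def __init__(
        self,
        iterator: Iterable[Any],           # e.g. a ray.data.DataIterator
        collate_fn: Callable[[Any], Any],  # CPU side: numpy dict -> batch
        finalize_fn: Callable[[Any], Any], # device side: move to GPU
        num_iters: int,
    ):
        self._source: Iterator[Any] = iter(iterator)
        self._collate_fn = collate_fn
        self._finalize_fn = finalize_fn
        self._num_iters = num_iters

    def __iter__(self):
        for _ in range(self._num_iters):
            try:
                raw = next(self._source)  # stateful: resumes mid-stream
            except StopIteration:
                return  # stream exhausted before num_iters batches
            yield self._finalize_fn(self._collate_fn(raw))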
@@ -1104,7 +1113,7 @@ def _finalize_fn(_batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]: |
1104 | 1113 | ) |
1105 | 1114 |
|
1106 | 1115 | # Loop through the minibatches (or the given data iterator).
1107 | | - for tensor_minibatch in batch_iter: |
| 1116 | + for iteration, tensor_minibatch in enumerate(batch_iter): |
1108 | 1117 | # Check whether our RLModule contains all ModuleIDs found in this
1109 | 1118 | # MultiAgentBatch. If not, raise an error.
1110 | 1119 | unknown_module_ids = set(tensor_minibatch.policy_batches.keys()) - set( |
@@ -1133,6 +1142,19 @@ def _finalize_fn(_batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]: |
1133 | 1142 |
|
1134 | 1143 | self._set_slicing_by_batch_id(tensor_minibatch, value=False) |
1135 | 1144 |
|
| 1145 | + if self.iterator: |
| 1146 | + # Record the number of batches pulled from the dataset. |
| 1147 | + self.metrics.log_value( |
| 1148 | + (ALL_MODULES, DATASET_NUM_ITERS_TRAINED), |
| 1149 | + iteration + 1, |
| 1150 | + reduce="sum", |
| 1151 | + clear_on_reduce=True, |
| 1152 | + ) |
| 1153 | + self.metrics.log_value( |
| 1154 | + (ALL_MODULES, DATASET_NUM_ITERS_TRAINED_LIFETIME), |
| 1155 | + iteration + 1, |
| 1156 | + reduce="sum", |
| 1157 | + ) |
1136 | 1158 | # Log all individual RLModules' loss terms and their registered optimizers'
1137 | 1159 | # current learning rates. |
1138 | 1160 | # Note: We do this only once for the last of the minibatch updates, b/c the |
|
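
Moving the dataset counters into the update loop means the logged value tracks iterations that actually ran rather than the requested `num_iters` (the stream can end early). A minimal sketch of that pattern with stubbed helpers (all names here are illustrative):

def train_step(minibatch) -> None:
    ...  # stub: one gradient update

def log_sum(key: str, value: int) -> None:
    ...  # stub: metrics sink with reduce="sum" semantics

def run_update(batch_iter) -> None:
    iterations = 0
    for iterations, minibatch in enumerate(batch_iter, start=1):
        train_step(minibatch)
    # Log how many minibatches were actually pulled; may be < num_iters.
    log_sum("dataset_num_iters_trained", iterations)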