Skip to content

Commit

Permalink
fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigodelazcano committed Dec 7, 2023
1 parent 701b8e1 commit cf35cd7
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 24 deletions.
1 change: 0 additions & 1 deletion minari/data_collector/data_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def _add_to_episode_buffer(
Returns:
Dict: new dictionary episode buffer with added values from step_data
"""

if self._record_infos and not self.check_infos_same_shape(
self._reference_info, step_data["infos"]
):
Expand Down
12 changes: 6 additions & 6 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,12 +582,12 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]):
assert total_steps == data.total_steps


def assert_infos_same_shape(info_1, info_2):
def assert_infos_same_structure(info_1, info_2):
if len(info_1.keys()) != len(info_2.keys()):
return False
for key in info_1.keys():
if isinstance(info_1[key], dict):
if not assert_infos_same_shape(info_1[key], info_2[key]):
if not assert_infos_same_structure(info_1[key], info_2[key]):
return False
elif isinstance(info_1[key], np.ndarray):
if not (info_1[key].shape == info_2[key].shape) and (
Expand All @@ -601,11 +601,11 @@ def assert_infos_same_shape(info_1, info_2):
return True


def _get_info_at_step_index(infos, step_index):
def get_info_at_step_index(infos, step_index):
result = {}
for key in infos.keys():
if isinstance(infos[key], dict):
result[key] = _get_info_at_step_index(infos[key], step_index)
result[key] = get_info_at_step_index(infos[key], step_index)
elif isinstance(infos[key], np.ndarray):
result[key] = infos[key][step_index]
else:
Expand Down Expand Up @@ -732,8 +732,8 @@ def check_episode_data_integrity(
for i in range(episode.total_timesteps + 1):
obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i)
if info_sample is not None:
assert assert_infos_same_shape(
_get_info_at_step_index(episode.infos, i), info_sample
assert assert_infos_same_structure(
get_info_at_step_index(episode.infos, i), info_sample
)
assert observation_space.contains(obs)

Expand Down
23 changes: 6 additions & 17 deletions tests/data_collector/test_data_collector.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import gymnasium as gym
import h5py
import numpy as np
import pytest

from minari import DataCollector, EpisodeData, MinariDataset, StepDataCallback
from tests.common import check_load_and_delete_dataset, register_dummy_envs
from tests.common import (
check_load_and_delete_dataset,
get_info_at_step_index,
register_dummy_envs,
)


MAX_UINT64 = np.iinfo(np.uint64).max
Expand All @@ -30,20 +33,6 @@ def __call__(self, env, **kwargs):
return step_data


def _get_step_from_infos(infos, step_index: int):
result = {}
for key in infos.keys():
if isinstance(infos[key], h5py.Group):
result[key] = _get_step_from_infos(infos[key], step_index)
elif isinstance(infos[key], h5py.Dataset):
result[key] = infos[key][step_index]
else:
raise ValueError(
"Infos are in an unsupported format; see Minari documentation for supported formats."
)
return result


def _get_step_from_dictionary_space(episode_data, step_index):
step_data = {}
assert isinstance(episode_data, dict)
Expand Down Expand Up @@ -86,7 +75,7 @@ def get_single_step_from_episode(episode: EpisodeData, index: int) -> EpisodeDat
else:
action = episode.actions[index]

infos = _get_step_from_infos(episode.infos, index)
infos = get_info_at_step_index(episode.infos, index)

step_data = {
"id": episode.id,
Expand Down

0 comments on commit cf35cd7

Please sign in to comment.