Commit 3e2dcc4

Adds infos to EpisodeData (#132)
* initial commit for infos fix

* tentative draft of info support for EpisodeData

* typing fix

* typing fixed, removed print

* added some information to dataset standards, started work on test for varying infos

* added tests for error in response to infos with timestep variant structure, and test of using StepDataCallback to fix it

* added explicit np.array dtype support, documentation, and tests

* updated doc page

* table syntax change

* remove print

* DataCollectorV0 -> DataCollector

* rename test

* move info shape check to add to buffer

* add _get_info in dummy test envs

* _get_info_at_step_index

* fix pre-commit

* fix tests

* fix docs

* fix pre-commit

* refactor

* simplify tests

* fix pre-commit

* remove redundant comments

* fix pre-commit

* fixes

* fix basic_usage

* fix episode_data repr

* fix common

* improve tests

---------

Co-authored-by: rodrigodelazcano <[email protected]>
Co-authored-by: Omar Younis <[email protected]>
3 people authored Jan 25, 2024
1 parent 237ec51 commit 3e2dcc4
Showing 10 changed files with 387 additions and 73 deletions.
12 changes: 6 additions & 6 deletions docs/content/basic_usage.md
@@ -77,7 +77,7 @@ for _ in range(total_episodes):
         if terminated or truncated:
             break
 
-dataset = env.create_dataset(dataset_id="CartPole-v1-test-v0",
+dataset = env.create_dataset(dataset_id="cartpole-test-v0",
                              algorithm_name="Random-Policy",
                              code_permalink="https://github.com/Farama-Foundation/Minari",
                              author="Farama",
@@ -96,7 +96,7 @@ Once the dataset has been created we can check if the Minari dataset id appears
 >>> import minari
 >>> local_datasets = minari.list_local_datasets()
 >>> local_datasets.keys()
-dict_keys(['CartPole-v1-test-v0'])
+dict_keys(['cartpole-test-v0'])
 ```
 
 ```{eval-rst}
@@ -125,7 +125,7 @@ env = gym.make('CartPole-v1')
 env = DataCollector(env, record_infos=True, max_buffer_steps=100000)
 
 total_episodes = 100
-dataset_name = "CartPole-v1-test-v0"
+dataset_name = "cartpole-test-v0"
 dataset = None
 if dataset_name in minari.list_local_datasets():
     dataset = minari.load_dataset(dataset_name)
@@ -161,9 +161,9 @@ Minari will only be able to load datasets that are stored in your `local root di
 
 ```python
 >>> import minari
->>> dataset = minari.load_dataset('CartPole-v1-test-v0')
+>>> dataset = minari.load_dataset('cartpole-test-v0')
 >>> dataset.name
-'CartPole-v1-test-v0'
+'cartpole-test-v0'
 ```
 
 ### Download Remote Datasets
@@ -323,7 +323,7 @@ From a :class:`minari.MinariDataset` object we can also recover the Gymnasium en
 ```python
 import minari
 
-dataset = minari.load_dataset('CartPole-v1-test-v0')
+dataset = minari.load_dataset('cartpole-test-v0')
 env = dataset.recover_environment()
 
 env.reset()
7 changes: 7 additions & 0 deletions docs/content/dataset_standards.md
@@ -554,5 +554,12 @@ The `sampled_episodes` variable will be a list of 10 `EpisodeData` elements, eac
 | `rewards` | `np.ndarray` | Rewards for each timestep. |
 | `terminations` | `np.ndarray` | Terminations for each timestep. |
 | `truncations` | `np.ndarray` | Truncations for each timestep. |
+| `infos` | `dict` | A dictionary containing additional information. |
 
 As mentioned in the `Supported Spaces` section, many different observation and action spaces are supported, so the data types for these fields depend on the environment being used.
+
+## Additional Information Formatting
+
+When creating a dataset with `DataCollector`, if the `DataCollector` is initialized with `record_infos=True`, an info dict must be provided from every call to the environment's `step` and `reset` functions. The structure of the info dictionary must be the same across timesteps.
+
+Given that not all Gymnasium environments are guaranteed to provide infos at every timestep, we provide the `StepDataCallback`, which can modify the infos from a non-compliant environment so that they have the same structure at every timestep. An example of this pattern is available in the test `test_data_collector_step_data_callback_info_correction` in `test_step_data_callback.py`.
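As a minimal sketch of that pattern, assuming a custom callback subclass: the callback name and the `success` key are hypothetical, while `DataCollector` and `StepDataCallback` are the Minari classes touched by this commit.

```python
import gymnasium as gym
import numpy as np

from minari import DataCollector, StepDataCallback


class InfoPaddingCallback(StepDataCallback):
    """Hypothetical callback giving every timestep the same info structure."""

    def __call__(self, env, **kwargs):
        step_data = super().__call__(env, **kwargs)
        # If the wrapped environment omits this key on some timesteps,
        # insert a placeholder of fixed shape and dtype so the structure
        # matches the reference info recorded at reset.
        step_data["infos"].setdefault("success", np.zeros(1, dtype=bool))
        return step_data


env = DataCollector(
    gym.make("CartPole-v1"),
    record_infos=True,
    step_data_callback=InfoPaddingCallback,
)
```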
86 changes: 59 additions & 27 deletions minari/data_collector/data_collector.py
@@ -128,6 +128,7 @@ def __init__(
         )
 
         self._record_infos = record_infos
+        self._reference_info = None
         self.max_buffer_steps = max_buffer_steps
 
         # Initialize empty buffer
@@ -136,11 +137,11 @@ def __init__(
         self._step_id = -1
         self._episode_id = -1
 
-    def _add_to_episode_buffer(
+    def _add_step_data(
         self,
         episode_buffer: EpisodeBuffer,
-        step_data: Union[StepData, Dict[str, StepData]],
-    ) -> EpisodeBuffer:
+        step_data: Union[StepData, Dict],
+    ):
         """Add step data dictionary to episode buffer.
 
         Args:
@@ -150,31 +151,43 @@ def _add_to_episode_buffer(
         Returns:
             Dict: new dictionary episode buffer with added values from step_data
         """
+        dict_data = dict(step_data)
+        if not self._record_infos:
+            dict_data = {k: v for k, v in step_data.items() if k != "infos"}
+        else:
+            assert self._reference_info is not None
+            if not _check_infos_same_shape(
+                self._reference_info, step_data["infos"]
+            ):
+                raise ValueError(
+                    "Info structure inconsistent with info structure returned by original reset."
+                )
+
+        self._add_to_episode_buffer(episode_buffer, dict_data)
+
+    def _add_to_episode_buffer(
+        self,
+        episode_buffer: EpisodeBuffer,
+        step_data: Dict[str, Any],
+    ):
         for key, value in step_data.items():
-            if (not self._record_infos and key == "infos") or (value is None):
+            if value is None:
                 continue
 
             if key not in episode_buffer:
-                if isinstance(value, dict):
-                    episode_buffer[key] = self._add_to_episode_buffer({}, value)
-                else:
-                    episode_buffer[key] = [value]
+                episode_buffer[key] = {} if isinstance(value, dict) else []
+
+            if isinstance(value, dict):
+                assert isinstance(
+                    episode_buffer[key], dict
+                ), f"Element to be inserted is type 'dict', but buffer accepts type {type(episode_buffer[key])}"
+
+                self._add_to_episode_buffer(episode_buffer[key], value)
             else:
-                if isinstance(value, dict):
-                    assert isinstance(
-                        episode_buffer[key], dict
-                    ), f"Element to be inserted is type 'dict', but buffer accepts type {type(episode_buffer[key])}"
-
-                    episode_buffer[key] = self._add_to_episode_buffer(
-                        episode_buffer[key], value
-                    )
-                else:
-                    assert isinstance(
-                        episode_buffer[key], list
-                    ), f"Element to be inserted is type 'list', but buffer accepts type {type(episode_buffer[key])}"
-                    episode_buffer[key].append(value)
-
-        return episode_buffer
+                assert isinstance(
+                    episode_buffer[key], list
+                ), f"Element to be inserted is type 'list', but buffer accepts type {type(episode_buffer[key])}"
+                episode_buffer[key].append(value)
 
     def step(
         self, action: ActType
@@ -191,6 +204,9 @@ def step(
             terminated=terminated,
             truncated=truncated,
         )
+
+        # Force step data dictionary to include keys corresponding to Gymnasium step returns:
+        # actions, observations, rewards, terminations, truncations, and infos
         assert STEP_DATA_KEYS.issubset(
             step_data.keys()
         ), "One or more required keys is missing from 'step-data'."
@@ -203,7 +219,7 @@ def step(
         ), "Actions are not in action space."
 
         self._step_id += 1
-        self._buffer[-1] = self._add_to_episode_buffer(self._buffer[-1], step_data)
+        self._add_step_data(self._buffer[-1], step_data)
 
         if (
             self.max_buffer_steps is not None
@@ -219,7 +235,7 @@ def step(
                 "observations": step_data["observations"],
                 "infos": step_data["infos"],
             }
-            eps_buff = self._add_to_episode_buffer(eps_buff, previous_data)
+            self._add_step_data(eps_buff, previous_data)
             self._buffer.append(eps_buff)
 
         return obs, rew, terminated, truncated, info
@@ -245,14 +261,17 @@ def reset(
             observation (ObsType): Observation of the initial state.
             info (dictionary): Auxiliary information complementing ``observation``.
         """
-        autoseed_enabled = (not options) or options.get("minari_autoseed", False)
+        autoseed_enabled = (not options) or options.get("minari_autoseed", True)
         if seed is None and autoseed_enabled:
             seed = secrets.randbits(AUTOSEED_BIT_SIZE)
 
         obs, info = self.env.reset(seed=seed, options=options)
         step_data = self._step_data_callback(env=self.env, obs=obs, info=info)
         self._episode_id += 1
+
+        if self._record_infos and self._reference_info is None:
+            self._reference_info = step_data["infos"]
+
         assert STEP_DATA_KEYS.issubset(
             step_data.keys()
         ), "One or more required keys is missing from 'step-data'"
@@ -262,7 +281,7 @@ def reset(
             "seed": str(None) if seed is None else seed,
             "id": self._episode_id
         }
-        episode_buffer = self._add_to_episode_buffer(episode_buffer, step_data)
+        self._add_step_data(episode_buffer, step_data)
         self._buffer.append(episode_buffer)
         return obs, info
 
@@ -418,3 +437,16 @@ def close(self):
 
         self._buffer.clear()
         shutil.rmtree(self._tmp_dir.name)
+
+
+def _check_infos_same_shape(info_1: dict, info_2: dict):
+    if info_1.keys() != info_2.keys():
+        return False
+    for key in info_1.keys():
+        if type(info_1[key]) is not type(info_2[key]):
+            return False
+        if isinstance(info_1[key], dict):
+            return _check_infos_same_shape(info_1[key], info_2[key])
+        elif isinstance(info_1[key], np.ndarray):
+            return (info_1[key].shape == info_2[key].shape) and (info_1[key].dtype == info_2[key].dtype)
+    return True
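To make the consistency rule concrete, here is a small illustration of the structures the shape check above accepts and rejects; the keys and values are hypothetical.

```python
import numpy as np

# Same keys, same array shapes and dtypes, same nesting -> consistent.
info_reset = {"velocity": np.zeros(2, dtype=np.float32), "debug": {"step": 0}}
info_step = {"velocity": np.ones(2, dtype=np.float32), "debug": {"step": 1}}
# _check_infos_same_shape(info_reset, info_step) -> True

# Same keys but a different dtype (float64 vs. float32) -> inconsistent;
# DataCollector raises ValueError when adding this step to the buffer.
info_bad = {"velocity": np.zeros(2, dtype=np.float64), "debug": {"step": 2}}
# _check_infos_same_shape(info_reset, info_bad) -> False
```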
4 changes: 3 additions & 1 deletion minari/dataset/episode_data.py
@@ -19,6 +19,7 @@ class EpisodeData:
     rewards: np.ndarray
     terminations: np.ndarray
     truncations: np.ndarray
+    infos: dict
 
     def __repr__(self) -> str:
         return (
@@ -30,7 +31,8 @@ def __repr__(self) -> str:
             f"actions={EpisodeData._repr_space_values(self.actions)}, "
             f"rewards=ndarray of {len(self.rewards)} floats, "
             f"terminations=ndarray of {len(self.terminations)} bools, "
-            f"truncations=ndarray of {len(self.truncations)} bools"
+            f"truncations=ndarray of {len(self.truncations)} bools, "
+            f"infos=dict with the following keys: {list(self.infos.keys())}"
             ")"
         )
21 changes: 20 additions & 1 deletion minari/dataset/minari_storage.py
@@ -88,7 +88,10 @@ def new(
         obj._action_space = action_space
 
         if env_spec is not None:
-            metadata["env_spec"] = env_spec.to_json()
+            try:
+                metadata["env_spec"] = env_spec.to_json()
+            except TypeError:
+                pass
         with h5py.File(obj._file_path, "a") as file:
             file.attrs.update(metadata)
         return obj
@@ -161,6 +164,19 @@ def apply(
         ep_dicts = self.get_episodes(episode_indices)
         return map(function, ep_dicts)
 
+    def _decode_infos(self, infos: h5py.Group):
+        result = {}
+        for key in infos.keys():
+            if isinstance(infos[key], h5py.Group):
+                result[key] = self._decode_infos(infos[key])
+            elif isinstance(infos[key], h5py.Dataset):
+                result[key] = infos[key][()]
+            else:
+                raise ValueError(
+                    "Infos are in an unsupported format; see Minari documentation for supported formats."
+                )
+        return result
+
     def _decode_space(
         self,
         hdf_ref: Union[h5py.Group, h5py.Dataset, h5py.Datatype],
@@ -219,6 +235,9 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]:
                     "actions": self._decode_space(
                         ep_group["actions"], self.action_space
                     ),
+                    "infos": self._decode_infos(ep_group["infos"])
+                    if "infos" in ep_group
+                    else {},
                 }
                 for key in {"rewards", "terminations", "truncations"}:
                     group_value = ep_group[key]