Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds infos to EpisodeData, tests for infos, and adds standards for infos to documentation #132

Merged
merged 29 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
edc00e2
initial commit for info's fix
balisujohn Aug 21, 2023
2d892c2
tentative draft of info support for EpisodeData
balisujohn Aug 25, 2023
08d708a
typing fix
balisujohn Aug 25, 2023
5dce7f4
typing fixed, removed print
balisujohn Aug 25, 2023
e93cfad
added some information to dataset stabdards started work on test for …
balisujohn Aug 28, 2023
3e05c49
added tests for error in response to infos with timestep variant stru…
balisujohn Aug 30, 2023
214052d
added explicit np.array dtype support, documentation, and tests
balisujohn Sep 3, 2023
2693fd1
updated doc page
balisujohn Sep 4, 2023
961c534
table syntax change
balisujohn Sep 4, 2023
a23d434
remove print
Nov 28, 2023
8dfe498
DataCollectorV0 -> DataCollector
Nov 28, 2023
02e8dd2
rename test
Nov 29, 2023
e84433b
move info shape check to add to buffer
Nov 30, 2023
fca36c1
add _get_info in dummy test envs
Nov 30, 2023
701b8e1
_get_info_at_step_index
Nov 30, 2023
cf35cd7
fix pre-commit
Dec 7, 2023
4cd3d4d
fix tests
younik Jan 20, 2024
14c1630
fix docs
younik Jan 20, 2024
4ed8612
fix pre-commit
younik Jan 20, 2024
ebbce77
refactor
younik Jan 20, 2024
d39aaca
simplify tests
younik Jan 21, 2024
0983ec2
fix pre-commit
younik Jan 21, 2024
dd252e9
remove redundant comments
younik Jan 21, 2024
7661f20
fix pre-commit
younik Jan 21, 2024
d42879e
fixes
younik Jan 22, 2024
348513e
fix basic_usage
younik Jan 25, 2024
5dd2de5
fix episode_data repr
younik Jan 25, 2024
0da6a84
fix common
younik Jan 25, 2024
9578d84
improe tests
younik Jan 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions docs/content/basic_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ for _ in range(total_episodes):
if terminated or truncated:
break

dataset = env.create_dataset(dataset_id="CartPole-v1-test-v0",
dataset = env.create_dataset(dataset_id="cartpole-test-v0",
algorithm_name="Random-Policy",
code_permalink="https://github.com/Farama-Foundation/Minari",
author="Farama",
Expand All @@ -96,7 +96,7 @@ Once the dataset has been created we can check if the Minari dataset id appears
>>> import minari
>>> local_datasets = minari.list_local_datasets()
>>> local_datasets.keys()
dict_keys(['CartPole-v1-test-v0'])
dict_keys(['cartpole-test-v0'])
```

```{eval-rst}
Expand Down Expand Up @@ -125,7 +125,7 @@ env = gym.make('CartPole-v1')
env = DataCollector(env, record_infos=True, max_buffer_steps=100000)

total_episodes = 100
dataset_name = "CartPole-v1-test-v0"
dataset_name = "cartpole-test-v0"
younik marked this conversation as resolved.
Show resolved Hide resolved
dataset = None
if dataset_name in minari.list_local_datasets():
dataset = minari.load_dataset(dataset_name)
Expand Down Expand Up @@ -161,9 +161,9 @@ Minari will only be able to load datasets that are stored in your `local root di

```python
>>> import minari
>>> dataset = minari.load_dataset('CartPole-v1-test-v0')
>>> dataset = minari.load_dataset('cartpole-test-v0')
>>> dataset.name
'CartPole-v1-test-v0'
'cartpole-test-v0'
```

### Download Remote Datasets
Expand Down Expand Up @@ -323,7 +323,7 @@ From a :class:`minari.MinariDataset` object we can also recover the Gymnasium en
```python
import minari

dataset = minari.load_dataset('CartPole-v1-test-v0')
dataset = minari.load_dataset('cartpole-test-v0')
env = dataset.recover_environment()

env.reset()
Expand Down
7 changes: 7 additions & 0 deletions docs/content/dataset_standards.md
Original file line number Diff line number Diff line change
Expand Up @@ -554,5 +554,12 @@ The `sampled_episodes` variable will be a list of 10 `EpisodeData` elements, eac
| `rewards` | `np.ndarray` | Rewards for each timestep. |
| `terminations` | `np.ndarray` | Terminations for each timestep. |
| `truncations` | `np.ndarray` | Truncations for each timestep. |
| `infos` | `dict` | A dictionary containing additional information. |

As mentioned in the `Supported Spaces` section, many different observation and action spaces are supported so the data type for these fields are dependent on the environment being used.

## Additional Information Formatting

When creating a dataset with `DataCollector`, if the `DataCollector` is initialized with `record_infos=True`, an info dict must be provided from every call to the environment's `step` and `reset` function. The structure of the info dictionary must be the same across timesteps.

Given that it is not guaranteed that all Gymnasium environments provide infos at every timestep, we provide the `StepDataCallback` which can modify the infos from a non-compliant environment so they have the same structure at every timestep. An example of this pattern is available in our test `test_data_collector_step_data_callback_info_correction` in test_step_data_callback.py.
86 changes: 59 additions & 27 deletions minari/data_collector/data_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __init__(
)

self._record_infos = record_infos
self._reference_info = None
self.max_buffer_steps = max_buffer_steps

# Initialzie empty buffer
Expand All @@ -136,11 +137,11 @@ def __init__(
self._step_id = -1
self._episode_id = -1

def _add_to_episode_buffer(
def _add_step_data(
self,
episode_buffer: EpisodeBuffer,
step_data: Union[StepData, Dict[str, StepData]],
) -> EpisodeBuffer:
step_data: Union[StepData, Dict],
):
"""Add step data dictionary to episode buffer.

Args:
Expand All @@ -150,31 +151,43 @@ def _add_to_episode_buffer(
Returns:
Dict: new dictionary episode buffer with added values from step_data
"""
dict_data = dict(step_data)
if not self._record_infos:
dict_data = {k: v for k, v in step_data.items() if k != "infos"}
else:
assert self._reference_info is not None
if not _check_infos_same_shape(
self._reference_info, step_data["infos"]
):
raise ValueError(
"Info structure inconsistent with info structure returned by original reset."
)

self._add_to_episode_buffer(episode_buffer, dict_data)

def _add_to_episode_buffer(
self,
episode_buffer: EpisodeBuffer,
step_data: Dict[str, Any],
):
for key, value in step_data.items():
if (not self._record_infos and key == "infos") or (value is None):
if value is None:
continue

if key not in episode_buffer:
if isinstance(value, dict):
episode_buffer[key] = self._add_to_episode_buffer({}, value)
else:
episode_buffer[key] = [value]
episode_buffer[key] = {} if isinstance(value, dict) else []

if isinstance(value, dict):
assert isinstance(
episode_buffer[key], dict
), f"Element to be inserted is type 'dict', but buffer accepts type {type(episode_buffer[key])}"

self._add_to_episode_buffer(episode_buffer[key], value)
else:
if isinstance(value, dict):
assert isinstance(
episode_buffer[key], dict
), f"Element to be inserted is type 'dict', but buffer accepts type {type(episode_buffer[key])}"

episode_buffer[key] = self._add_to_episode_buffer(
episode_buffer[key], value
)
else:
assert isinstance(
episode_buffer[key], list
), f"Element to be inserted is type 'list', but buffer accepts type {type(episode_buffer[key])}"
episode_buffer[key].append(value)

return episode_buffer
assert isinstance(
episode_buffer[key], list
), f"Element to be inserted is type 'list', but buffer accepts type {type(episode_buffer[key])}"
episode_buffer[key].append(value)

def step(
self, action: ActType
Expand All @@ -191,6 +204,9 @@ def step(
terminated=terminated,
truncated=truncated,
)

# Force step data dictionary to include keys corresponding to Gymnasium step returns:
# actions, observations, rewards, terminations, truncations, and infos
assert STEP_DATA_KEYS.issubset(
step_data.keys()
), "One or more required keys is missing from 'step-data'."
Expand All @@ -203,7 +219,7 @@ def step(
), "Actions are not in action space."

self._step_id += 1
self._buffer[-1] = self._add_to_episode_buffer(self._buffer[-1], step_data)
self._add_step_data(self._buffer[-1], step_data)

if (
self.max_buffer_steps is not None
Expand All @@ -219,7 +235,7 @@ def step(
"observations": step_data["observations"],
"infos": step_data["infos"],
}
eps_buff = self._add_to_episode_buffer(eps_buff, previous_data)
self._add_step_data(eps_buff, previous_data)
self._buffer.append(eps_buff)

return obs, rew, terminated, truncated, info
Expand All @@ -245,14 +261,17 @@ def reset(
observation (ObsType): Observation of the initial state.
info (dictionary): Auxiliary information complementing ``observation``.
"""
autoseed_enabled = (not options) or options.get("minari_autoseed", False)
autoseed_enabled = (not options) or options.get("minari_autoseed", True)
if seed is None and autoseed_enabled:
seed = secrets.randbits(AUTOSEED_BIT_SIZE)

obs, info = self.env.reset(seed=seed, options=options)
step_data = self._step_data_callback(env=self.env, obs=obs, info=info)
self._episode_id += 1

if self._record_infos and self._reference_info is None:
self._reference_info = step_data["infos"]

assert STEP_DATA_KEYS.issubset(
step_data.keys()
), "One or more required keys is missing from 'step-data'"
Expand All @@ -262,7 +281,7 @@ def reset(
"seed": str(None) if seed is None else seed,
"id": self._episode_id
}
episode_buffer = self._add_to_episode_buffer(episode_buffer, step_data)
self._add_step_data(episode_buffer, step_data)
self._buffer.append(episode_buffer)
return obs, info

Expand Down Expand Up @@ -418,3 +437,16 @@ def close(self):

self._buffer.clear()
shutil.rmtree(self._tmp_dir.name)


def _check_infos_same_shape(info_1: dict, info_2: dict):
if info_1.keys() != info_2.keys():
return False
for key in info_1.keys():
if type(info_1[key]) is not type(info_2[key]):
return False
if isinstance(info_1[key], dict):
return _check_infos_same_shape(info_1[key], info_2[key])
elif isinstance(info_1[key], np.ndarray):
return (info_1[key].shape == info_2[key].shape) and (info_1[key].dtype == info_2[key].dtype)
return True
4 changes: 3 additions & 1 deletion minari/dataset/episode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class EpisodeData:
rewards: np.ndarray
terminations: np.ndarray
truncations: np.ndarray
infos: dict

def __repr__(self) -> str:
return (
Expand All @@ -30,7 +31,8 @@ def __repr__(self) -> str:
f"actions={EpisodeData._repr_space_values(self.actions)}, "
f"rewards=ndarray of {len(self.rewards)} floats, "
f"terminations=ndarray of {len(self.terminations)} bools, "
f"truncations=ndarray of {len(self.truncations)} bools"
f"truncations=ndarray of {len(self.truncations)} bools, "
f"infos=dict with the following keys: {list(self.infos.keys())}"
")"
)

Expand Down
21 changes: 20 additions & 1 deletion minari/dataset/minari_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ def new(
obj._action_space = action_space

if env_spec is not None:
metadata["env_spec"] = env_spec.to_json()
try:
metadata["env_spec"] = env_spec.to_json()
except TypeError:
pass
with h5py.File(obj._file_path, "a") as file:
file.attrs.update(metadata)
return obj
Expand Down Expand Up @@ -161,6 +164,19 @@ def apply(
ep_dicts = self.get_episodes(episode_indices)
return map(function, ep_dicts)

def _decode_infos(self, infos: h5py.Group):
younik marked this conversation as resolved.
Show resolved Hide resolved
result = {}
for key in infos.keys():
if isinstance(infos[key], h5py.Group):
result[key] = self._decode_infos(infos[key])
elif isinstance(infos[key], h5py.Dataset):
result[key] = infos[key][()]
else:
raise ValueError(
"Infos are in an unsupported format; see Minari documentation for supported formats."
)
return result

def _decode_space(
self,
hdf_ref: Union[h5py.Group, h5py.Dataset, h5py.Datatype],
Expand Down Expand Up @@ -219,6 +235,9 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]:
"actions": self._decode_space(
ep_group["actions"], self.action_space
),
"infos": self._decode_infos(ep_group["infos"])
if "infos" in ep_group
else {},
}
for key in {"rewards", "terminations", "truncations"}:
group_value = ep_group[key]
Expand Down
Loading
Loading