From 15336fff839118498d4ed7cb4186815e904de432 Mon Sep 17 00:00:00 2001
From: "John U. Balis"
Date: Sat, 22 Jul 2023 16:55:17 -0400
Subject: [PATCH 1/5] attempted to fix list_remote_datasets slowdown

---
 minari/storage/hosting.py | 40 ++++++++++++++++++---------------------
 1 file changed, 18 insertions(+), 22 deletions(-)

diff --git a/minari/storage/hosting.py b/minari/storage/hosting.py
index 942f6a7f..674055be 100644
--- a/minari/storage/hosting.py
+++ b/minari/storage/hosting.py
@@ -225,33 +225,29 @@ def list_remote_datasets(
         Dict[str, Dict[str, str]]: keys the names of the Minari datasets and values the metadata
     """
     client = storage.Client.create_anonymous_client()
-    bucket = client.bucket("minari-datasets")
-    blobs = bucket.list_blobs(delimiter="main_data.hdf5")
-
-    # Necessary to get prefixes iterable
-    next(blobs)
+    blobs = client.list_blobs(bucket_or_name="minari-datasets")
 
     # Generate dict = {'env_name-dataset_name': (version, metadata)}
     remote_datasets = {}
-    for prefix in sorted(blobs.prefixes):
-        blob = bucket.get_blob(prefix)
+    for blob in blobs:
         try:
-            metadata = blob.metadata
-            if compatible_minari_version and __version__ not in SpecifierSet(
-                metadata["minari_version"]
-            ):
-                continue
-            dataset_id = metadata["dataset_id"]
-            env_name, dataset_name, version = parse_dataset_id(dataset_id)
-            dataset = f"{env_name}-{dataset_name}"
-            if latest_version:
-                if (
-                    dataset not in remote_datasets
-                    or version > remote_datasets[dataset][0]
-                ):
-                    remote_datasets[dataset] = (version, metadata)
-            else:
-                remote_datasets[dataset_id] = metadata
+            if blob.name.endswith("main_data.hdf5"):
+                metadata = blob.metadata
+                if compatible_minari_version and __version__ not in SpecifierSet(
+                    metadata["minari_version"]
+                ):
+                    continue
+                dataset_id = metadata["dataset_id"]
+                env_name, dataset_name, version = parse_dataset_id(dataset_id)
+                dataset = f"{env_name}-{dataset_name}"
+                if latest_version:
+                    if (
+                        dataset not in remote_datasets
+                        or version > remote_datasets[dataset][0]
+                    ):
+                        remote_datasets[dataset] = (version, metadata)
+                else:
+                    remote_datasets[dataset_id] = metadata
         except Exception:
             warnings.warn(f"Misconfigured dataset named {blob.name} on remote")
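The slowdown this patch addresses came from issuing one bucket.get_blob() request per dataset prefix; a single list_blobs() call returns every blob with blob.metadata already populated, so the main_data.hdf5 filter can run client-side in one round trip. A minimal sketch of the pattern, assuming anonymous network access to the public minari-datasets bucket:

    from google.cloud import storage

    # One listing call replaces N per-dataset get_blob round trips; the
    # patched loop relies on metadata being populated on listed blobs.
    client = storage.Client.create_anonymous_client()
    for blob in client.list_blobs(bucket_or_name="minari-datasets"):
        if blob.name.endswith("main_data.hdf5"):
            print(blob.name, blob.metadata)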
Balis" Date: Sun, 30 Jul 2023 18:40:05 -0400 Subject: [PATCH 2/5] tester.py draft --- tester.py | 262 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 262 insertions(+) create mode 100644 tester.py diff --git a/tester.py b/tester.py new file mode 100644 index 00000000..99d8a6bb --- /dev/null +++ b/tester.py @@ -0,0 +1,262 @@ +import copy +from collections import OrderedDict +from typing import Dict +import datetime +import random +from operator import itemgetter + +import gymnasium as gym +import numpy as np +import pytest +from gymnasium import spaces +import pickle + +import minari +from minari import DataCollectorV0, MinariDataset +from tests.common import ( + register_dummy_envs, +) + + +NUM_EPISODES = 10000 +EPISODE_SAMPLE_COUNT = 10 + +register_dummy_envs() + + +def test_generate_dataset_with_collector_env(dataset_id, env_id): + """Test DataCollectorV0 wrapper and Minari dataset creation.""" + # dataset_id = "cartpole-test-v0" + # delete the test dataset if it already exists + local_datasets = minari.list_local_datasets() + if dataset_id in local_datasets: + minari.delete_dataset(dataset_id) + + env = gym.make(env_id) + + env = DataCollectorV0(env) + + # Step the environment, DataCollectorV0 wrapper will do the data collection job + env.reset(seed=42) + + for episode in range(NUM_EPISODES): + terminated = False + truncated = False + while not terminated and not truncated: + action = env.action_space.sample() # User-defined policy function + _, _, terminated, truncated, _ = env.step(action) + if terminated or truncated: + assert not env._buffer[-1] + else: + assert env._buffer[-1] + + env.reset() + + # Create Minari dataset and store locally + dataset = minari.create_dataset_from_collector_env( + dataset_id=dataset_id, + collector_env=env, + algorithm_name="random_policy", + code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", + author="WillDudley", + author_email="wdudley@farama.org", + ) + + + +def test_generate_dataset_with_external_buffer(dataset_id, env_id): + """Test create dataset from external buffers without using DataCollectorV0.""" + buffer = [] + # dataset_id = "cartpole-test-v0" + + + env = gym.make(env_id) + + observations = [] + actions = [] + rewards = [] + terminations = [] + truncations = [] + + + observation, info = env.reset(seed=42) + + # Step the environment, DataCollectorV0 wrapper will do the data collection job + observation, _ = env.reset() + observations.append(observation) + for episode in range(NUM_EPISODES): + terminated = False + truncated = False + + while not terminated and not truncated: + action = env.action_space.sample() # User-defined policy function + observation, reward, terminated, truncated, _ = env.step(action) + observations.append(observation) + actions.append(action) + rewards.append(reward) + terminations.append(terminated) + truncations.append(truncated) + + episode_buffer = { + "observations": copy.deepcopy(observations), + "actions": copy.deepcopy(actions), + "rewards": np.asarray(rewards), + "terminations": np.asarray(terminations), + "truncations": np.asarray(truncations), + } + buffer.append(episode_buffer) + + observations.clear() + actions.clear() + rewards.clear() + terminations.clear() + truncations.clear() + + observation, _ = env.reset() + observations.append(observation) + + # Create Minari dataset and store locally + dataset = minari.create_dataset_from_buffers( + dataset_id=dataset_id, + env=env, + 
+
+
+
+def test_generate_dataset_pickle(dataset_id, env_id):
+    """Test pickling an external buffer directly, without DataCollectorV0 or Minari storage."""
+    buffer = []
+    # dataset_id = "cartpole-test-v0"
+
+
+    env = gym.make(env_id)
+
+    observations = []
+    actions = []
+    rewards = []
+    terminations = []
+    truncations = []
+
+
+    observation, info = env.reset(seed=42)
+
+    # Step the environment and record transitions manually (no DataCollectorV0 wrapper here)
+    observation, _ = env.reset()
+    observations.append(observation)
+    for episode in range(NUM_EPISODES):
+        terminated = False
+        truncated = False
+
+        while not terminated and not truncated:
+            action = env.action_space.sample()  # User-defined policy function
+            observation, reward, terminated, truncated, _ = env.step(action)
+            observations.append(observation)
+            actions.append(action)
+            rewards.append(reward)
+            terminations.append(terminated)
+            truncations.append(truncated)
+
+        episode_buffer = {
+            "observations": copy.deepcopy(observations),
+            "actions": copy.deepcopy(actions),
+            "rewards": np.asarray(rewards),
+            "terminations": np.asarray(terminations),
+            "truncations": np.asarray(truncations),
+        }
+        buffer.append(episode_buffer)
+
+        observations.clear()
+        actions.clear()
+        rewards.clear()
+        terminations.clear()
+        truncations.clear()
+
+        observation, _ = env.reset()
+        observations.append(observation)
+
+    # Store the raw buffer locally with pickle
+    with open("test.pkl", "wb") as test_file:
+        pickle.dump(buffer, test_file)
+
+    # with open("test.pkl", "rb") as test_file:
+    #     test = pickle.load(test_file)
+
+
+def test_sample_n_random_episodes_from_minari_dataset(dataset_id):
+    dataset = minari.load_dataset(dataset_id)
+    episodes = dataset.sample_episodes(EPISODE_SAMPLE_COUNT)
+    # print(episodes)
+
+def test_sample_n_random_episodes_from_pickle_dataset():
+    with open("test.pkl", "rb") as test_file:
+        test = pickle.load(test_file)
+
+    indices = random.sample(range(0, len(test)), EPISODE_SAMPLE_COUNT)
+
+    result = itemgetter(*indices)(test)
+
+
+
+def measure(function, args):
+    before = datetime.datetime.now()
+    function(*args)
+    after = datetime.datetime.now()
+    return (after - before).total_seconds()
+
+
+if __name__ == "__main__":
+
+
+    environment_list = [
+        ("cartpole-test-v0", "CartPole-v1"),
+        ("dummy-dict-test-v0", "DummyDictEnv-v0"),
+        ("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
+        ("dummy-text-test-v0", "DummyTextEnv-v0"),
+        ("dummy-combo-test-v0", "DummyComboEnv-v0"),
+        ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"),
+    ]
+
+
+    measurements = {}
+
+
+
+    for dataset_id, env_id in environment_list:
+
+        # dataset_id, env_id = ("cartpole-test-v0", "CartPole-v1")
+
+
+        # delete the test dataset if it already exists
+        local_datasets = minari.list_local_datasets()
+        if dataset_id in local_datasets:
+            minari.delete_dataset(dataset_id)
+
+        result = measure(test_generate_dataset_with_collector_env, (dataset_id, env_id))
+        print(f"Time to generate {NUM_EPISODES} episodes with {env_id} using test_generate_dataset_with_collector_env: {result}")
+
+        # delete the test dataset if it already exists
+        local_datasets = minari.list_local_datasets()
+        if dataset_id in local_datasets:
+            minari.delete_dataset(dataset_id)
+
+
+        result = measure(test_generate_dataset_with_external_buffer, (dataset_id, env_id))
+        print(f"Time to generate {NUM_EPISODES} episodes with {env_id} using test_generate_dataset_with_external_buffer: {result}")
+
+
+
+        result = measure(test_generate_dataset_pickle, (dataset_id, env_id))
+        print(f"Time to generate {NUM_EPISODES} episodes with {env_id} using test_generate_dataset_pickle: {result}")
+
+        result = measure(test_sample_n_random_episodes_from_minari_dataset, (dataset_id,))
+        print(f"Time to sample {EPISODE_SAMPLE_COUNT} episodes from {env_id} using test_sample_n_random_episodes_from_minari_dataset: {result}")
+
+
+        result = measure(test_sample_n_random_episodes_from_pickle_dataset, ())
+        print(f"Time to sample {EPISODE_SAMPLE_COUNT} episodes from {env_id} using test_sample_n_random_episodes_from_pickle_dataset: {result}")
\ No newline at end of file
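Note that tester.py is a throwaway benchmark rather than a pytest module: the test_* functions take positional arguments and are driven from __main__. One caveat is that datetime.datetime.now() measures wall-clock time, which can jump if the system clock adjusts mid-run; time.perf_counter() is the usual monotonic choice for timing. A sketch of the same measure() helper on that basis, offered as an alternative rather than what the patch uses:

    import time

    def measure(function, args):
        # perf_counter is monotonic and high-resolution, so the elapsed
        # time cannot go negative or jump with system clock adjustments.
        before = time.perf_counter()
        function(*args)
        return time.perf_counter() - before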
From 4c0f466a43e6f7cd2bb615ce4ded8061ce7b3866 Mon Sep 17 00:00:00 2001
From: "John U. Balis"
Date: Mon, 31 Jul 2023 20:07:11 -0400
Subject: [PATCH 3/5] patch to speed up sampling from a minari dataset,
 MinariStorage total_steps tests coverage and bugfix

---
 minari/dataset/minari_dataset.py | 15 +++++++++------
 minari/dataset/minari_storage.py | 12 +++++++++---
 minari/utils.py                  |  6 +++---
 tests/common.py                  |  5 +++++
 4 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py
index 2cd443c2..7f69d72e 100644
--- a/minari/dataset/minari_dataset.py
+++ b/minari/dataset/minari_dataset.py
@@ -141,17 +141,20 @@ def __init__(
         self._additional_data_id = 0
         if episode_indices is None:
             episode_indices = np.arange(self._data.total_episodes)
+            total_steps = self._data.total_steps
+        else:
+            total_steps = sum(
+                self._data.apply(
+                    lambda episode: episode["total_timesteps"],
+                    episode_indices=episode_indices,
+                )
+            )
 
         self._episode_indices = episode_indices
 
         assert self._episode_indices is not None
 
-        total_steps = sum(
-            self._data.apply(
-                lambda episode: episode["total_timesteps"],
-                episode_indices=self._episode_indices,
-            )
-        )
+
 
         self.spec = MinariDatasetSpec(
             env_spec=self._data.env_spec,
diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py
index c8fc2bb3..a508b07d 100644
--- a/minari/dataset/minari_storage.py
+++ b/minari/dataset/minari_storage.py
@@ -205,15 +205,18 @@ def update_from_collector_env(
                     "id", last_episode_id + id
                 )
 
+        self._total_steps = file.attrs["total_steps"] + new_data_total_steps
+
         # Update metadata of minari dataset
         file.attrs.modify(
             "total_episodes", last_episode_id + new_data_total_episodes
         )
         file.attrs.modify(
-            "total_steps", file.attrs["total_steps"] + new_data_total_steps
+            "total_steps", self._total_steps
         )
         self._total_episodes = int(file.attrs["total_episodes"].item())
 
+
     def update_from_buffer(self, buffer: List[dict], data_path: str):
         additional_steps = 0
         with h5py.File(data_path, "a", track_order=True) as file:
@@ -247,9 +250,12 @@ def update_from_buffer(self, buffer: List[dict], data_path: str):
 
             # TODO: save EpisodeMetadataCallback callback in MinariDataset and update new episode group metadata
 
-            file.attrs.modify("total_episodes", last_episode_id + len(buffer))
+            self._total_steps = file.attrs["total_steps"] + additional_steps
+            self._total_episodes = last_episode_id + len(buffer)
+
+            file.attrs.modify("total_episodes", self._total_episodes)
             file.attrs.modify(
-                "total_steps", file.attrs["total_steps"] + additional_steps
+                "total_steps", self._total_steps
             )
             self._total_episodes = int(file.attrs["total_episodes"].item())
diff --git a/minari/utils.py b/minari/utils.py
index 025335a8..0945c705 100644
--- a/minari/utils.py
+++ b/minari/utils.py
@@ -467,9 +467,9 @@ def create_dataset_from_buffers(
             )
             eps_group.attrs["id"] = i
 
-            total_steps = len(eps_buff["actions"])
-            eps_group.attrs["total_steps"] = total_steps
-            total_steps += total_steps
+            episode_total_steps = len(eps_buff["actions"])
+            eps_group.attrs["total_steps"] = episode_total_steps
+            total_steps += episode_total_steps
 
         if seed is None:
             eps_group.attrs["seed"] = str(None)
diff --git a/tests/common.py b/tests/common.py
index e4cdc1dc..4fc1dbf4 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -461,8 +461,10 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]):
     print([episode["id"] for episode in episodes])
     # verify we have the right number of episodes, available at the right indices
     assert data.total_episodes == len(episodes)
+    total_steps = 0
    # verify the actions and observations are in the appropriate action space and observation space, and that the episode lengths are correct
     for episode in episodes:
+        total_steps += episode["total_timesteps"]
         _check_space_elem(
             episode["observations"],
             data.observation_space,
@@ -484,6 +486,9 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]):
         assert episode["total_timesteps"] == len(episode["rewards"])
         assert episode["total_timesteps"] == len(episode["terminations"])
         assert episode["total_timesteps"] == len(episode["truncations"])
+    print(total_steps)
+    print(data.total_steps)
+    assert total_steps == data.total_steps
 
 
 def _reconstuct_obs_or_action_at_index_recursive(
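Previously MinariDataset.__init__ recomputed total_steps by applying a lambda over every episode group even when no episode subset was requested, a full pass over the HDF5 file on each load. With this patch the aggregate is read from storage in the common case and recomputed only for explicit episode subsets, while both update_from_* paths keep the stored attribute in sync incrementally. A minimal sketch of that cached-aggregate pattern in h5py, using a throwaway demo.hdf5 rather than Minari's actual file layout:

    import h5py

    with h5py.File("demo.hdf5", "a") as f:
        if "total_steps" not in f.attrs:
            f.attrs["total_steps"] = 0
        # Appending an episode of, say, 100 steps updates the aggregate in
        # place, so readers never need to re-scan every episode group.
        f.attrs.modify("total_steps", f.attrs["total_steps"] + 100)
        print(int(f.attrs["total_steps"]))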
From 8148589f0718aa70662cdf135d2d2c9aa1558bd1 Mon Sep 17 00:00:00 2001
From: "John U. Balis"
Date: Thu, 10 Aug 2023 03:43:57 -0400
Subject: [PATCH 4/5] added basic trajectory sampling

---
 minari/dataset/minari_dataset.py     | 54 ++++++++++++++++++++++++++++-
 tests/dataset/test_minari_dataset.py | 40 ++++++++++++++++++++
 2 files changed, 93 insertions(+), 1 deletion(-)

diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py
index 7f69d72e..e2ea19d3 100644
--- a/minari/dataset/minari_dataset.py
+++ b/minari/dataset/minari_dataset.py
@@ -4,6 +4,9 @@
 import re
 from dataclasses import dataclass, field
 from typing import Callable, Iterable, Iterator, List, Optional, Union
+from random import choice, choices, randrange
+from collections import Counter
+
 
 import gymnasium as gym
 import numpy as np
@@ -41,11 +44,15 @@ def parse_dataset_id(dataset_id: str) -> tuple[str | None, str, int]:
     return env_name, dataset_name, version
 
 
+
+
 @dataclass(frozen=True)
 class EpisodeData:
     """Contains the datasets data for a single episode.
 
-    This is the object returned by :class:`minari.MinariDataset.sample_episodes`.
+    This is the object returned by :class:`minari.MinariDataset.sample_episodes` and :class:`minari.MinariDataset.sample_trajectories`.
+
+    In instances of `EpisodeData` returned by :class:`minari.MinariDataset.sample_trajectories`, `id` refers to the id of the starting episode.
     """
 
     id: int
@@ -236,6 +243,51 @@ def sample_episodes(self, n_episodes: int) -> Iterable[EpisodeData]:
         episodes = self._data.get_episodes(indices)
         return list(map(lambda data: EpisodeData(**data), episodes))
 
+
+    def sample_trajectories(self, n_trajectories, trajectory_length, allow_restarts=False):
+        if allow_restarts:
+            total_steps = self.spec.total_steps
+
+            starts = choices(range(total_steps - trajectory_length + 1), k=n_trajectories)
+            raise NotImplementedError  # TODO: map each start offset back to its episode
+        else:
+
+
+            #We only want to load each episode once. We need to discard episodes that are too short for our trajectory length.
+            #We mark such episodes so they will not be sampled again, while always preserving uniform random sampling overal all
+            #samples not known to be invalid.
+            valid_episode_indices = {key:1 for key in range(0,self.spec.total_episodes)}
+            counts = {}
+            episodes = []
+            samples = 0
+            while samples < n_trajectories:
+                index = choice(list(valid_episode_indices.keys()))
+                if index in counts:
+                    counts[index] += 1
+                    samples += 1
+                else:
+                    sampled_episode = self._data.get_episodes([index])[0]
+                    if sampled_episode["total_timesteps"] < trajectory_length:
+                        del valid_episode_indices[index]
+                    else:
+                        episodes.append(sampled_episode)
+                        samples += 1
+                        counts[index] = 1
+
+            result = []
+            for episode in episodes:
+                for _ in range(counts[episode["id"]]):
+                    if episode["total_timesteps"] == trajectory_length:
+                        result.append(EpisodeData(**episode))
+                    elif episode["total_timesteps"] > trajectory_length:
+                        start = randrange(0, episode["total_timesteps"] - trajectory_length + 1)
+                        trajectory = EpisodeData(id=episode["id"], seed=episode["seed"], total_timesteps=trajectory_length, observations=episode["observations"][start:start + trajectory_length + 1], actions=episode["actions"][start:start + trajectory_length], rewards=episode["rewards"][start:start + trajectory_length], terminations=episode["terminations"][start:start + trajectory_length], truncations=episode["truncations"][start:start + trajectory_length])
+                        result.append(trajectory)
+                    else:
+                        assert False
+            return result
+
+
     def iterate_episodes(
         self, episode_indices: Optional[List[int]] = None
     ) -> Iterator[EpisodeData]:
diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py
index 5a2be869..44b77f7a 100644
--- a/tests/dataset/test_minari_dataset.py
+++ b/tests/dataset/test_minari_dataset.py
@@ -336,6 +336,46 @@ def filter_by_index(episode: Any):
     env.close()
 
 
+
+@pytest.mark.parametrize(
+    "dataset_id,env_id",
+    [
+        ("cartpole-test-v0", "CartPole-v1"),
+        ("dummy-dict-test-v0", "DummyDictEnv-v0"),
+        ("dummy-box-test-v0", "DummyBoxEnv-v0"),
+        ("dummy-tuple-test-v0", "DummyTupleEnv-v0"),
+        ("dummy-combo-test-v0", "DummyComboEnv-v0"),
+        ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"),
+    ],
+)
+def test_sample_trajectories(dataset_id, env_id):
+    local_datasets = minari.list_local_datasets()
+    if dataset_id in local_datasets:
+        minari.delete_dataset(dataset_id)
+
+    env = gym.make(env_id)
+
+    env = DataCollectorV0(env)
+    num_episodes = 10
+
+    dataset = create_dummy_dataset_with_collecter_env_helper(
+        dataset_id, env, num_episodes=num_episodes
+    )
+
+
+    episodes = dataset.sample_trajectories(4, 5)
+
+
+    check_episode_data_integrity(
+        episodes,
+        dataset.spec.observation_space,
+        dataset.spec.action_space,
+    )
+
+    env.close()
+
+
+
 @pytest.mark.parametrize(
     "dataset_id,env_id",
     [
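A usage sketch for the new method, assuming a local dataset with id cartpole-test-v0 already exists (for example, one produced by tester.py above):

    import minari

    dataset = minari.load_dataset("cartpole-test-v0")
    # Draw 4 trajectories of exactly 5 steps each; episodes shorter than
    # 5 steps are rejected and never sampled again.
    trajectories = dataset.sample_trajectories(4, 5)
    for trajectory in trajectories:
        print(trajectory.id, trajectory.total_timesteps)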
From 5a4a17acd7cff6623408f5d3c19661440980971a Mon Sep 17 00:00:00 2001
From: "John U. Balis"
Date: Thu, 10 Aug 2023 03:46:09 -0400
Subject: [PATCH 5/5] small typo fix

---
 minari/dataset/minari_dataset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py
index e2ea19d3..61060df6 100644
--- a/minari/dataset/minari_dataset.py
+++ b/minari/dataset/minari_dataset.py
@@ -254,8 +254,8 @@ def sample_trajectories(self, n_trajectories, trajectory_length, allow_restarts=
 
 
             #We only want to load each episode once. We need to discard episodes that are too short for our trajectory length.
-            #We mark such episodes so they will not be sampled again, while always preserving uniform random sampling overal all
-            #samples not known to be invalid.
+            #We mark such episodes so they will not be sampled again, while always preserving uniform random sampling over all
+            #episodes not known to be invalid.
             valid_episode_indices = {key:1 for key in range(0,self.spec.total_episodes)}
             counts = {}
             episodes = []
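The corrected comment states the sampling invariant: each draw is uniform over the episodes still present in valid_episode_indices, and an episode is deleted only once it is known to be too short, so every draw is uniform over the episodes not yet known to be invalid. A standalone sketch of that rejection loop, with a hypothetical is_valid predicate standing in for the episode-length check:

    from random import choice

    def sample_valid(keys, is_valid, k):
        # Each draw is uniform over keys not yet known to be invalid; a
        # key leaves the pool only once it is observed to be invalid.
        pool = dict.fromkeys(keys, True)
        out = []
        while len(out) < k:
            if not pool:
                raise ValueError("no valid keys remain")
            key = choice(list(pool))
            if is_valid(key):
                out.append(key)  # sampling with replacement, like the patch
            else:
                del pool[key]
        return out

    print(sample_valid(range(10), lambda key: key % 2 == 0, 3))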