Skip to content

Commit

Permalink
fix total_steps computation (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
younik authored Aug 29, 2024
1 parent ed39de1 commit e3cfe2d
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions minari/dataset/minari_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,12 @@ def __init__(
else:
raise ValueError(f"Unrecognized type {type(data)} for data")

self._total_steps = None
if episode_indices is None:
episode_indices = np.arange(self._data.total_episodes)
self._total_steps = self._data.total_steps
assert episode_indices is not None
self._episode_indices: npt.NDArray[np.int_] = episode_indices
self._total_steps = None

metadata = self._data.metadata

Expand Down Expand Up @@ -307,13 +308,10 @@ def total_episodes(self) -> int:
def total_steps(self) -> int:
"""Total episodes steps in the Minari dataset."""
if self._total_steps is None:
if self.episode_indices is None:
self._total_steps = self.storage.total_steps
else:
self._total_steps = 0
metadatas = self.storage.get_episode_metadata(self.episode_indices)
for m in metadatas:
self._total_steps += m["total_steps"]
self._total_steps = 0
metadatas = self.storage.get_episode_metadata(self.episode_indices)
for m in metadatas:
self._total_steps += m["total_steps"]
return int(self._total_steps)

@property
Expand Down

0 comments on commit e3cfe2d

Please sign in to comment.