Skip to content

Commit 445d142

Browse files
Fixed bug in dataset.py
- unsqueeze does not exist in NumPy - now using the array backend
1 parent 3295221 commit 445d142

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

mushroom_rl/core/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ def compute_J(self, gamma=1.):
477477
r_ep = split_episodes(self.last, self.reward)
478478

479479
if len(r_ep.shape) == 1:
480-
r_ep = r_ep.unsqueeze(0)
480+
r_ep = self._array_backend.expand_dims(r_ep, 0)
481481
if self._dataset_info.backend == 'torch':
482482
js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype, device=r_ep.device)
483483
else:

0 commit comments

Comments
 (0)