-
Notifications
You must be signed in to change notification settings - Fork 8.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix space utils for Discrete with non-zero start (#2645)
* Fix flatten utils to handle Discrete.start * Fix vector space utils to handle Discrete.start * More granular dispatch in vector utils * Fix Box including the high end of the interval
- Loading branch information
1 parent
108f32c
commit e671aa1
Showing
6 changed files
with
48 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,37 +43,44 @@ def batch_space(space, n=1): | |
|
||
|
||
@batch_space.register(Box) | ||
@batch_space.register(Discrete) | ||
@batch_space.register(MultiDiscrete) | ||
@batch_space.register(MultiBinary) | ||
def batch_space_base(space, n=1): | ||
if isinstance(space, Box): | ||
repeats = tuple([n] + [1] * space.low.ndim) | ||
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats) | ||
return Box(low=low, high=high, dtype=space.dtype) | ||
def _batch_space_box(space, n=1): | ||
repeats = tuple([n] + [1] * space.low.ndim) | ||
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats) | ||
return Box(low=low, high=high, dtype=space.dtype) | ||
|
||
elif isinstance(space, Discrete): | ||
|
||
@batch_space.register(Discrete) | ||
def _batch_space_discrete(space, n=1): | ||
if space.start == 0: | ||
return MultiDiscrete(np.full((n,), space.n, dtype=space.dtype)) | ||
else: | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
RedTachyon
Contributor
|
||
return Box( | ||
low=space.start, | ||
high=space.start + space.n - 1, | ||
shape=(n,), | ||
dtype=space.dtype, | ||
) | ||
|
||
elif isinstance(space, MultiDiscrete): | ||
repeats = tuple([n] + [1] * space.nvec.ndim) | ||
high = np.tile(space.nvec, repeats) - 1 | ||
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype) | ||
|
||
elif isinstance(space, MultiBinary): | ||
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype) | ||
@batch_space.register(MultiDiscrete) | ||
def _batch_space_multidiscrete(space, n=1): | ||
repeats = tuple([n] + [1] * space.nvec.ndim) | ||
high = np.tile(space.nvec, repeats) - 1 | ||
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype) | ||
|
||
else: | ||
raise ValueError(f"Space type `{type(space)}` is not supported.") | ||
|
||
@batch_space.register(MultiBinary) | ||
def _batch_space_multibinary(space, n=1): | ||
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype) | ||
|
||
|
||
@batch_space.register(Tuple) | ||
def batch_space_tuple(space, n=1): | ||
def _batch_space_tuple(space, n=1): | ||
return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces)) | ||
|
||
|
||
@batch_space.register(Dict) | ||
def batch_space_dict(space, n=1): | ||
def _batch_space_dict(space, n=1): | ||
return Dict( | ||
OrderedDict( | ||
[ | ||
|
@@ -85,7 +92,7 @@ def batch_space_dict(space, n=1): | |
|
||
|
||
@batch_space.register(Space) | ||
def batch_space_custom(space, n=1): | ||
def _batch_space_custom(space, n=1): | ||
return Tuple(tuple(space for _ in range(n))) | ||
|
||
|
||
|
@@ -130,22 +137,22 @@ def iterate(space, items): | |
|
||
|
||
@iterate.register(Discrete) | ||
def iterate_discrete(space, items): | ||
def _iterate_discrete(space, items): | ||
raise TypeError("Unable to iterate over a space of type `Discrete`.") | ||
|
||
|
||
@iterate.register(Box) | ||
@iterate.register(MultiDiscrete) | ||
@iterate.register(MultiBinary) | ||
def iterate_base(space, items): | ||
def _iterate_base(space, items): | ||
try: | ||
return iter(items) | ||
except TypeError: | ||
raise TypeError(f"Unable to iterate over the following elements: {items}") | ||
|
||
|
||
@iterate.register(Tuple) | ||
def iterate_tuple(space, items): | ||
def _iterate_tuple(space, items): | ||
# If this is a tuple of custom subspaces only, then simply iterate over items | ||
if all( | ||
isinstance(subspace, Space) | ||
|
@@ -160,7 +167,7 @@ def iterate_tuple(space, items): | |
|
||
|
||
@iterate.register(Dict) | ||
def iterate_dict(space, items): | ||
def _iterate_dict(space, items): | ||
keys, values = zip( | ||
*[ | ||
(key, iterate(subspace, items[key])) | ||
|
@@ -172,7 +179,7 @@ def iterate_dict(space, items): | |
|
||
|
||
@iterate.register(Space) | ||
def iterate_custom(space, items): | ||
def _iterate_custom(space, items): | ||
raise CustomSpaceError( | ||
f"Unable to iterate over {items}, since {space} " | ||
"is a custom `gym.Space` instance (i.e. not one of " | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
it seems strange that the result of batching of discrete spaces depends on the
start
. Wouldn't it be better to always return Box? Or maybe addstart
to MultiDiscrete?