Skip to content

Commit

Permalink
Fix space utils for Discrete with non-zero start (#2645)
Browse files Browse the repository at this point in the history
* Fix flatten utils to handle Discrete.start

* Fix vector space utils to handle Discrete.start

* More granular dispatch in vector utils

* Fix Box including the high end of the interval
  • Loading branch information
tristandeleu authored Mar 4, 2022
1 parent 108f32c commit e671aa1
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 28 deletions.
4 changes: 2 additions & 2 deletions gym/spaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _flatten_box_multibinary(space, x) -> np.ndarray:
@flatten.register(Discrete)
def _flatten_discrete(space, x) -> np.ndarray:
onehot = np.zeros(space.n, dtype=space.dtype)
onehot[x] = 1
onehot[x - space.start] = 1
return onehot


Expand Down Expand Up @@ -124,7 +124,7 @@ def _unflatten_box_multibinary(space: Box | MultiBinary, x: np.ndarray) -> np.nd

@unflatten.register(Discrete)
def _unflatten_discrete(space: Discrete, x: np.ndarray) -> int:
return int(np.nonzero(x)[0][0])
return int(space.start + np.nonzero(x)[0][0])


@unflatten.register(MultiDiscrete)
Expand Down
57 changes: 32 additions & 25 deletions gym/vector/utils/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,37 +43,44 @@ def batch_space(space, n=1):


@batch_space.register(Box)
@batch_space.register(Discrete)
@batch_space.register(MultiDiscrete)
@batch_space.register(MultiBinary)
def batch_space_base(space, n=1):
if isinstance(space, Box):
repeats = tuple([n] + [1] * space.low.ndim)
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
return Box(low=low, high=high, dtype=space.dtype)
def _batch_space_box(space, n=1):
repeats = tuple([n] + [1] * space.low.ndim)
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
return Box(low=low, high=high, dtype=space.dtype)

elif isinstance(space, Discrete):

@batch_space.register(Discrete)
def _batch_space_discrete(space, n=1):
if space.start == 0:
return MultiDiscrete(np.full((n,), space.n, dtype=space.dtype))
else:

This comment has been minimized.

Copy link
@pzhokhov

pzhokhov Mar 7, 2022

Collaborator

it seems strange that the result of batching of discrete spaces depends on the start. Wouldn't it be better to always return Box? Or maybe add start to MultiDiscrete?

This comment has been minimized.

Copy link
@RedTachyon

RedTachyon Mar 10, 2022

Contributor

I think it's one of the cases where both options make some sense. Right now it returns the space that "makes the most sense" in a given situation at the cost of some consistency. We could make it consistent, but then we get slightly odd behavior.

Adding start to MultiDiscrete would probably be the best, but that's a whole another PR.

return Box(
low=space.start,
high=space.start + space.n - 1,
shape=(n,),
dtype=space.dtype,
)

elif isinstance(space, MultiDiscrete):
repeats = tuple([n] + [1] * space.nvec.ndim)
high = np.tile(space.nvec, repeats) - 1
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype)

elif isinstance(space, MultiBinary):
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)
@batch_space.register(MultiDiscrete)
def _batch_space_multidiscrete(space, n=1):
repeats = tuple([n] + [1] * space.nvec.ndim)
high = np.tile(space.nvec, repeats) - 1
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype)

else:
raise ValueError(f"Space type `{type(space)}` is not supported.")

@batch_space.register(MultiBinary)
def _batch_space_multibinary(space, n=1):
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)


@batch_space.register(Tuple)
def batch_space_tuple(space, n=1):
def _batch_space_tuple(space, n=1):
return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces))


@batch_space.register(Dict)
def batch_space_dict(space, n=1):
def _batch_space_dict(space, n=1):
return Dict(
OrderedDict(
[
Expand All @@ -85,7 +92,7 @@ def batch_space_dict(space, n=1):


@batch_space.register(Space)
def batch_space_custom(space, n=1):
def _batch_space_custom(space, n=1):
return Tuple(tuple(space for _ in range(n)))


Expand Down Expand Up @@ -130,22 +137,22 @@ def iterate(space, items):


@iterate.register(Discrete)
def iterate_discrete(space, items):
def _iterate_discrete(space, items):
raise TypeError("Unable to iterate over a space of type `Discrete`.")


@iterate.register(Box)
@iterate.register(MultiDiscrete)
@iterate.register(MultiBinary)
def iterate_base(space, items):
def _iterate_base(space, items):
try:
return iter(items)
except TypeError:
raise TypeError(f"Unable to iterate over the following elements: {items}")


@iterate.register(Tuple)
def iterate_tuple(space, items):
def _iterate_tuple(space, items):
# If this is a tuple of custom subspaces only, then simply iterate over items
if all(
isinstance(subspace, Space)
Expand All @@ -160,7 +167,7 @@ def iterate_tuple(space, items):


@iterate.register(Dict)
def iterate_dict(space, items):
def _iterate_dict(space, items):
keys, values = zip(
*[
(key, iterate(subspace, items[key]))
Expand All @@ -172,7 +179,7 @@ def iterate_dict(space, items):


@iterate.register(Space)
def iterate_custom(space, items):
def _iterate_custom(space, items):
raise CustomSpaceError(
f"Unable to iterate over {items}, since {space} "
"is a custom `gym.Space` instance (i.e. not one of "
Expand Down
12 changes: 11 additions & 1 deletion tests/spaces/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
),
}
),
Discrete(3, start=2),
Discrete(8, start=-5),
]

flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7]
flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7, 3, 8]


@pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims))
Expand Down Expand Up @@ -123,6 +125,8 @@ def compare_nested(left, right):
np.int64,
np.int8,
np.float64,
np.int64,
np.int64,
]


Expand Down Expand Up @@ -187,6 +191,8 @@ def compare_sample_types(original_space, original_sample, unflattened_sample):
OrderedDict(
[("position", 3), ("velocity", np.array([0.5, 3.5], dtype=np.float32))]
),
3,
-2,
]


Expand All @@ -200,6 +206,8 @@ def compare_sample_types(original_space, original_sample, unflattened_sample):
np.array([1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.int64),
np.array([0, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=np.int8),
np.array([0, 0, 0, 1, 0, 0.5, 3.5], dtype=np.float64),
np.array([0, 1, 0], dtype=np.int64),
np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.int64),
]


Expand Down Expand Up @@ -243,6 +251,8 @@ def test_unflatten(space, flattened_sample, expected_sample):
high=np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], dtype=np.float64),
dtype=np.float64,
),
Box(low=0, high=1, shape=(3,), dtype=np.int64),
Box(low=0, high=1, shape=(8,), dtype=np.int64),
]


Expand Down
1 change: 1 addition & 0 deletions tests/vector/test_shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Array("B", 1),
Array("B", 32 * 32 * 3),
Array("i", 1),
Array("i", 1),
(Array("i", 1), Array("i", 1)),
(Array("i", 1), Array("f", 2)),
Array("B", 3),
Expand Down
1 change: 1 addition & 0 deletions tests/vector/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
Box(low=0, high=255, shape=(4, 32, 32, 3), dtype=np.uint8),
MultiDiscrete([2, 2, 2, 2]),
Box(low=-2, high=2, shape=(4,), dtype=np.int64),
Tuple((MultiDiscrete([3, 3, 3, 3]), MultiDiscrete([5, 5, 5, 5]))),
Tuple(
(
Expand Down
1 change: 1 addition & 0 deletions tests/vector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Box(low=0, high=255, shape=(), dtype=np.uint8),
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
Discrete(2),
Discrete(5, start=-2),
Tuple((Discrete(3), Discrete(5))),
Tuple(
(
Expand Down

0 comments on commit e671aa1

Please sign in to comment.