diff --git a/gymnax/environments/spaces.py b/gymnax/environments/spaces.py index 4faffda..487d91f 100755 --- a/gymnax/environments/spaces.py +++ b/gymnax/environments/spaces.py @@ -16,18 +16,22 @@ class Space: def sample(self, rng: chex.PRNGKey) -> chex.Array: raise NotImplementedError - def contains(self, x: jnp.int_) -> Any: + def contains(self, x: jnp.ndarray) -> Any: raise NotImplementedError class Discrete(Space): """Minimal jittable class for discrete gymnax spaces.""" - def __init__(self, num_categories: int): + def __init__( + self, + num_categories: int, + dtype: jnp.dtype = jnp.int_ + ): assert num_categories >= 0 self.n = num_categories self.shape = () - self.dtype = jnp.int_ + self.dtype = dtype def sample(self, rng: chex.PRNGKey) -> chex.Array: """Sample random action uniformly from set of categorical choices.""" @@ -35,7 +39,7 @@ def sample(self, rng: chex.PRNGKey) -> chex.Array: rng, shape=self.shape, minval=0, maxval=self.n ).astype(self.dtype) - def contains(self, x: jnp.int_) -> jnp.ndarray: + def contains(self, x: jnp.ndarray) -> jnp.ndarray: """Check whether specific object is within space.""" # type_cond = isinstance(x, self.dtype) # shape_cond = (x.shape == self.shape) @@ -64,7 +68,7 @@ def sample(self, rng: chex.PRNGKey) -> chex.Array: rng, shape=self.shape, minval=self.low, maxval=self.high ).astype(self.dtype) - def contains(self, x: jnp.int_) -> jnp.ndarray: + def contains(self, x: jnp.ndarray) -> jnp.ndarray: """Check whether specific object is within space.""" # type_cond = isinstance(x, self.dtype) # shape_cond = (x.shape == self.shape) @@ -89,7 +93,7 @@ def sample(self, rng: chex.PRNGKey) -> Any: # Dict: ] ) - def contains(self, x: jnp.int_) -> bool: + def contains(self, x: jnp.ndarray) -> bool: """Check whether dimensions of object are within subspace.""" # type_cond = isinstance(x, Dict) # num_space_cond = len(x) != len(self.spaces) @@ -112,7 +116,7 @@ def sample(self, rng: chex.PRNGKey) -> Any: # Tuple[chex.Array]: key_split = jax.random.split(rng, self.num_spaces) return tuple([s.sample(key_split[i]) for i, s in enumerate(self.spaces)]) - def contains(self, x: jnp.int_) -> bool: + def contains(self, x: jnp.ndarray) -> bool: """Check whether dimensions of object are within subspace.""" # type_cond = isinstance(x, tuple) # num_space_cond = len(x) != len(self.spaces)