Skip to content

Commit

Permalink
Allow overriding dtype in Discrete
Browse files Browse the repository at this point in the history
  • Loading branch information
mickvangelderen committed Nov 4, 2024
1 parent aef77d5 commit 642340c
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions gymnax/environments/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,30 @@ class Space:
def sample(self, rng: chex.PRNGKey) -> chex.Array:
raise NotImplementedError

def contains(self, x: jnp.int_) -> Any:
def contains(self, x: jnp.ndarray) -> Any:
raise NotImplementedError


class Discrete(Space):
"""Minimal jittable class for discrete gymnax spaces."""

def __init__(self, num_categories: int):
def __init__(
self,
num_categories: int,
dtype: jnp.dtype = jnp.int_
):
assert num_categories >= 0
self.n = num_categories
self.shape = ()
self.dtype = jnp.int_
self.dtype = dtype

def sample(self, rng: chex.PRNGKey) -> chex.Array:
"""Sample random action uniformly from set of categorical choices."""
return jax.random.randint(
rng, shape=self.shape, minval=0, maxval=self.n
).astype(self.dtype)

def contains(self, x: jnp.int_) -> jnp.ndarray:
def contains(self, x: jnp.ndarray) -> jnp.ndarray:
"""Check whether specific object is within space."""
# type_cond = isinstance(x, self.dtype)
# shape_cond = (x.shape == self.shape)
Expand Down Expand Up @@ -64,7 +68,7 @@ def sample(self, rng: chex.PRNGKey) -> chex.Array:
rng, shape=self.shape, minval=self.low, maxval=self.high
).astype(self.dtype)

def contains(self, x: jnp.int_) -> jnp.ndarray:
def contains(self, x: jnp.ndarray) -> jnp.ndarray:
"""Check whether specific object is within space."""
# type_cond = isinstance(x, self.dtype)
# shape_cond = (x.shape == self.shape)
Expand All @@ -89,7 +93,7 @@ def sample(self, rng: chex.PRNGKey) -> Any: # Dict:
]
)

def contains(self, x: jnp.int_) -> bool:
def contains(self, x: jnp.ndarray) -> bool:
"""Check whether dimensions of object are within subspace."""
# type_cond = isinstance(x, Dict)
# num_space_cond = len(x) != len(self.spaces)
Expand All @@ -112,7 +116,7 @@ def sample(self, rng: chex.PRNGKey) -> Any: # Tuple[chex.Array]:
key_split = jax.random.split(rng, self.num_spaces)
return tuple([s.sample(key_split[i]) for i, s in enumerate(self.spaces)])

def contains(self, x: jnp.int_) -> bool:
def contains(self, x: jnp.ndarray) -> bool:
"""Check whether dimensions of object are within subspace."""
# type_cond = isinstance(x, tuple)
# num_space_cond = len(x) != len(self.spaces)
Expand Down

0 comments on commit 642340c

Please sign in to comment.