Skip to content

Commit

Permalink
Order enforcing wrapper fix (#1205)
Browse files Browse the repository at this point in the history
  • Loading branch information
dm-ackerman committed Jun 21, 2024
1 parent 1eef080 commit 1282a0a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 55 deletions.
14 changes: 0 additions & 14 deletions pettingzoo/utils/env_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,6 @@ def warn_action_out_of_bound(
f"[WARNING]: Received an action {action} that was outside action space {action_space}. Environment is {backup_policy}"
)

@staticmethod
def warn_close_unrendered_env() -> None:
"""Warns: ``[WARNING]: Called close on an unrendered environment.``."""
EnvLogger._generic_warning(
"[WARNING]: Called close on an unrendered environment."
)

@staticmethod
def warn_close_before_reset() -> None:
"""Warns: ``[WARNING]: reset() needs to be called before close.``."""
EnvLogger._generic_warning(
"[WARNING]: reset() needs to be called before close."
)

@staticmethod
def warn_on_illegal_move() -> None:
"""Warns: ``[WARNING]: Illegal move made, game terminating with current player losing.``."""
Expand Down
60 changes: 19 additions & 41 deletions pettingzoo/utils/wrappers/order_enforcing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,26 @@
class OrderEnforcingWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
"""Checks if function calls or attribute access are in a disallowed order.
* error on getting rewards, terminations, truncations, infos, agent_selection before reset
* error on calling step, observe before reset
* error on iterating without stepping or resetting environment.
* warn on calling close before render or reset
* warn on calling step after environment is terminated or truncated
The following are raised:
* AttributeError if any of the following are accessed before reset():
rewards, terminations, truncations, infos, agent_selection,
num_agents, agents.
* An error if any of the following are called before reset:
render(), step(), observe(), state(), agent_iter()
* A warning if step() is called when there are no agents remaining.
"""

def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
assert isinstance(
env, AECEnv
), "OrderEnforcingWrapper is only compatible with AEC environments"
self._has_reset = False
self._has_rendered = False
self._has_updated = False
super().__init__(env)

def __getattr__(self, value: str) -> Any:
"""Raises an error message when data is gotten from the env.
Should only be gotten after reset
"""
if value == "unwrapped":
return self.env.unwrapped
elif value == "render_mode" and hasattr(self.env, "render_mode"):
return self.env.render_mode # pyright: ignore[reportGeneralTypeIssues]
elif value == "possible_agents":
try:
return self.env.possible_agents
except AttributeError:
EnvLogger.error_possible_agents_attribute_missing("possible_agents")
elif value == "observation_spaces":
raise AttributeError(
"The base environment does not have an possible_agents attribute. Use the environments `observation_space` method instead"
)
elif value == "action_spaces":
raise AttributeError(
"The base environment does not have an possible_agents attribute. Use the environments `action_space` method instead"
)
elif value == "agent_order":
raise AttributeError(
"agent_order has been removed from the API. Please consider using agent_iter instead."
)
elif (
"""Raises an error if certain data is accessed before reset."""
if (
value
in {
"rewards",
Expand All @@ -75,13 +52,11 @@ def __getattr__(self, value: str) -> Any:
and not self._has_reset
):
raise AttributeError(f"{value} cannot be accessed before reset")
else:
return super().__getattr__(value)
return super().__getattr__(value)

def render(self) -> None | np.ndarray | str | list:
if not self._has_reset:
EnvLogger.error_render_before_reset()
self._has_rendered = True
return super().render()

def step(self, action: ActionType) -> None:
Expand All @@ -90,7 +65,6 @@ def step(self, action: ActionType) -> None:
elif not self.agents:
self._has_updated = True
EnvLogger.warn_step_after_terminated_truncated()
return None
else:
self._has_updated = True
super().step(action)
Expand Down Expand Up @@ -124,8 +98,7 @@ def __str__(self) -> str:
if self.__class__ is OrderEnforcingWrapper
else f"{type(self).__name__}<{str(self.env)}>"
)
else:
return repr(self)
return repr(self)


class AECOrderEnforcingIterable(AECIterable[AgentID, ObsType, ActionType]):
Expand All @@ -134,11 +107,16 @@ def __iter__(self) -> AECOrderEnforcingIterator[AgentID, ObsType, ActionType]:


class AECOrderEnforcingIterator(AECIterator[AgentID, ObsType, ActionType]):
def __init__(
self, env: OrderEnforcingWrapper[AgentID, ObsType, ActionType], max_iter: int
):
assert isinstance(
env, OrderEnforcingWrapper
), "env must be wrapped by OrderEnforcingWrapper"
super().__init__(env, max_iter)

def __next__(self) -> AgentID:
agent = super().__next__()
assert hasattr(
self.env, "_has_updated"
), "env must be wrapped by OrderEnforcingWrapper"
assert (
self.env._has_updated # pyright: ignore[reportGeneralTypeIssues]
), "need to call step() or reset() in a loop over `agent_iter`"
Expand Down

0 comments on commit 1282a0a

Please sign in to comment.