Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Order enforcing wrapper fix #1205

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions pettingzoo/utils/env_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,6 @@ def warn_action_out_of_bound(
f"[WARNING]: Received an action {action} that was outside action space {action_space}. Environment is {backup_policy}"
)

@staticmethod
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why exactly was this stuff removed? Looks kind of messy code and isn’t done in the other wrappers so can probably see it just want to make sure it’s on purpose that it got removed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is dead code that is never called. I think it was maybe intended to be used in the OrderEnfocingWrapper - and maybe was used in an older version of it, but it's not used anymore. The code is specific to that wrapper, if it's not used there, there's no other place it would be useful.

def warn_close_unrendered_env() -> None:
"""Warns: ``[WARNING]: Called close on an unrendered environment.``."""
EnvLogger._generic_warning(
"[WARNING]: Called close on an unrendered environment."
)

@staticmethod
def warn_close_before_reset() -> None:
"""Warns: ``[WARNING]: reset() needs to be called before close.``."""
EnvLogger._generic_warning(
"[WARNING]: reset() needs to be called before close."
)

@staticmethod
def warn_on_illegal_move() -> None:
"""Warns: ``[WARNING]: Illegal move made, game terminating with current player losing.``."""
Expand Down
60 changes: 19 additions & 41 deletions pettingzoo/utils/wrappers/order_enforcing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,26 @@
class OrderEnforcingWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
"""Checks if function calls or attribute access are in a disallowed order.

* error on getting rewards, terminations, truncations, infos, agent_selection before reset
* error on calling step, observe before reset
* error on iterating without stepping or resetting environment.
* warn on calling close before render or reset
* warn on calling step after environment is terminated or truncated
The following are raised:
* AttributeError if any of the following are accessed before reset():
rewards, terminations, truncations, infos, agent_selection,
num_agents, agents.
* An error if any of the following are called before reset:
render(), step(), observe(), state(), agent_iter()
* A warning if step() is called when there are no agents remaining.
"""

def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
assert isinstance(
env, AECEnv
), "OrderEnforcingWrapper is only compatible with AEC environments"
self._has_reset = False
self._has_rendered = False
self._has_updated = False
super().__init__(env)

def __getattr__(self, value: str) -> Any:
"""Raises an error message when data is gotten from the env.

Should only be gotten after reset
"""
if value == "unwrapped":
return self.env.unwrapped
elif value == "render_mode" and hasattr(self.env, "render_mode"):
return self.env.render_mode # pyright: ignore[reportGeneralTypeIssues]
elif value == "possible_agents":
try:
return self.env.possible_agents
except AttributeError:
EnvLogger.error_possible_agents_attribute_missing("possible_agents")
elif value == "observation_spaces":
raise AttributeError(
"The base environment does not have an possible_agents attribute. Use the environments `observation_space` method instead"
)
elif value == "action_spaces":
raise AttributeError(
"The base environment does not have an possible_agents attribute. Use the environments `action_space` method instead"
)
elif value == "agent_order":
raise AttributeError(
"agent_order has been removed from the API. Please consider using agent_iter instead."
)
elif (
"""Raises an error if certain data is accessed before reset."""
if (
value
in {
"rewards",
Expand All @@ -75,13 +52,11 @@ def __getattr__(self, value: str) -> Any:
and not self._has_reset
):
raise AttributeError(f"{value} cannot be accessed before reset")
else:
return super().__getattr__(value)
return super().__getattr__(value)

def render(self) -> None | np.ndarray | str | list:
if not self._has_reset:
EnvLogger.error_render_before_reset()
self._has_rendered = True
return super().render()

def step(self, action: ActionType) -> None:
Expand All @@ -90,7 +65,6 @@ def step(self, action: ActionType) -> None:
elif not self.agents:
self._has_updated = True
EnvLogger.warn_step_after_terminated_truncated()
return None
else:
self._has_updated = True
super().step(action)
Expand Down Expand Up @@ -124,8 +98,7 @@ def __str__(self) -> str:
if self.__class__ is OrderEnforcingWrapper
else f"{type(self).__name__}<{str(self.env)}>"
)
else:
return repr(self)
return repr(self)


class AECOrderEnforcingIterable(AECIterable[AgentID, ObsType, ActionType]):
Expand All @@ -134,11 +107,16 @@ def __iter__(self) -> AECOrderEnforcingIterator[AgentID, ObsType, ActionType]:


class AECOrderEnforcingIterator(AECIterator[AgentID, ObsType, ActionType]):
def __init__(
self, env: OrderEnforcingWrapper[AgentID, ObsType, ActionType], max_iter: int
):
assert isinstance(
env, OrderEnforcingWrapper
), "env must be wrapped by OrderEnforcingWrapper"
super().__init__(env, max_iter)

def __next__(self) -> AgentID:
agent = super().__next__()
assert hasattr(
self.env, "_has_updated"
), "env must be wrapped by OrderEnforcingWrapper"
assert (
self.env._has_updated # pyright: ignore[reportGeneralTypeIssues]
), "need to call step() or reset() in a loop over `agent_iter`"
Expand Down
Loading