diff --git a/pettingzoo/utils/env_logger.py b/pettingzoo/utils/env_logger.py index c5e640e47..bd505e2e3 100644 --- a/pettingzoo/utils/env_logger.py +++ b/pettingzoo/utils/env_logger.py @@ -61,20 +61,6 @@ def warn_action_out_of_bound( f"[WARNING]: Received an action {action} that was outside action space {action_space}. Environment is {backup_policy}" ) - @staticmethod - def warn_close_unrendered_env() -> None: - """Warns: ``[WARNING]: Called close on an unrendered environment.``.""" - EnvLogger._generic_warning( - "[WARNING]: Called close on an unrendered environment." - ) - - @staticmethod - def warn_close_before_reset() -> None: - """Warns: ``[WARNING]: reset() needs to be called before close.``.""" - EnvLogger._generic_warning( - "[WARNING]: reset() needs to be called before close." - ) - @staticmethod def warn_on_illegal_move() -> None: """Warns: ``[WARNING]: Illegal move made, game terminating with current player losing.``.""" diff --git a/pettingzoo/utils/wrappers/order_enforcing.py b/pettingzoo/utils/wrappers/order_enforcing.py index 649c23caa..4a1255682 100644 --- a/pettingzoo/utils/wrappers/order_enforcing.py +++ b/pettingzoo/utils/wrappers/order_enforcing.py @@ -19,11 +19,13 @@ class OrderEnforcingWrapper(BaseWrapper[AgentID, ObsType, ActionType]): """Checks if function calls or attribute access are in a disallowed order. - * error on getting rewards, terminations, truncations, infos, agent_selection before reset - * error on calling step, observe before reset - * error on iterating without stepping or resetting environment. - * warn on calling close before render or reset - * warn on calling step after environment is terminated or truncated + The following are raised: + * AttributeError if any of the following are accessed before reset(): + rewards, terminations, truncations, infos, agent_selection, + num_agents, agents. + * An error if any of the following are called before reset: + render(), step(), observe(), state(), agent_iter() + * A warning if step() is called when there are no agents remaining. """ def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]): @@ -31,37 +33,12 @@ def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]): env, AECEnv ), "OrderEnforcingWrapper is only compatible with AEC environments" self._has_reset = False - self._has_rendered = False self._has_updated = False super().__init__(env) def __getattr__(self, value: str) -> Any: - """Raises an error message when data is gotten from the env. - - Should only be gotten after reset - """ - if value == "unwrapped": - return self.env.unwrapped - elif value == "render_mode" and hasattr(self.env, "render_mode"): - return self.env.render_mode # pyright: ignore[reportGeneralTypeIssues] - elif value == "possible_agents": - try: - return self.env.possible_agents - except AttributeError: - EnvLogger.error_possible_agents_attribute_missing("possible_agents") - elif value == "observation_spaces": - raise AttributeError( - "The base environment does not have an possible_agents attribute. Use the environments `observation_space` method instead" - ) - elif value == "action_spaces": - raise AttributeError( - "The base environment does not have an possible_agents attribute. Use the environments `action_space` method instead" - ) - elif value == "agent_order": - raise AttributeError( - "agent_order has been removed from the API. Please consider using agent_iter instead." - ) - elif ( + """Raises an error if certain data is accessed before reset.""" + if ( value in { "rewards", @@ -75,13 +52,11 @@ def __getattr__(self, value: str) -> Any: and not self._has_reset ): raise AttributeError(f"{value} cannot be accessed before reset") - else: - return super().__getattr__(value) + return super().__getattr__(value) def render(self) -> None | np.ndarray | str | list: if not self._has_reset: EnvLogger.error_render_before_reset() - self._has_rendered = True return super().render() def step(self, action: ActionType) -> None: @@ -90,7 +65,6 @@ def step(self, action: ActionType) -> None: elif not self.agents: self._has_updated = True EnvLogger.warn_step_after_terminated_truncated() - return None else: self._has_updated = True super().step(action) @@ -124,8 +98,7 @@ def __str__(self) -> str: if self.__class__ is OrderEnforcingWrapper else f"{type(self).__name__}<{str(self.env)}>" ) - else: - return repr(self) + return repr(self) class AECOrderEnforcingIterable(AECIterable[AgentID, ObsType, ActionType]): @@ -134,11 +107,16 @@ def __iter__(self) -> AECOrderEnforcingIterator[AgentID, ObsType, ActionType]: class AECOrderEnforcingIterator(AECIterator[AgentID, ObsType, ActionType]): + def __init__( + self, env: OrderEnforcingWrapper[AgentID, ObsType, ActionType], max_iter: int + ): + assert isinstance( + env, OrderEnforcingWrapper + ), "env must be wrapped by OrderEnforcingWrapper" + super().__init__(env, max_iter) + def __next__(self) -> AgentID: agent = super().__next__() - assert hasattr( - self.env, "_has_updated" - ), "env must be wrapped by OrderEnforcingWrapper" assert ( self.env._has_updated # pyright: ignore[reportGeneralTypeIssues] ), "need to call step() or reset() in a loop over `agent_iter`"