Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implemented RepeatAction wrapper #990

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gymnasium/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
TimeLimit,
)
from gymnasium.wrappers.rendering import HumanRendering, RecordVideo, RenderCollection
from gymnasium.wrappers.stateful_action import StickyAction
from gymnasium.wrappers.stateful_action import RepeatAction, StickyAction
from gymnasium.wrappers.stateful_observation import (
DelayObservation,
FrameStackObservation,
Expand Down
80 changes: 76 additions & 4 deletions gymnasium/wrappers/stateful_action.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
"""``StickyAction`` wrapper - There is a probability that the action is taken again."""
"""A collection of wrappers for modifying actions.

* ``StickyAction`` wrapper - There is a probability that the action is taken again.
* ``RepeatAction`` wrapper - Repeat a single action multiple times.
"""
from __future__ import annotations

from typing import Any
from typing import Any, SupportsFloat

import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
from gymnasium.error import InvalidProbability


__all__ = ["StickyAction"]
__all__ = ["StickyAction", "RepeatAction"]


class StickyAction(
Expand Down Expand Up @@ -80,3 +84,71 @@ def action(self, action: ActType) -> ActType:

self.last_action = action
return action


class RepeatAction(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Repeatedly executes a given action in the underlying environment.

    Upon calling the `step` method of this wrapper, `num_repeats`-many steps will be taken
    with the same action in the underlying environment.
    The wrapper sums the rewards collected from the underlying environment and returns the last
    environment state observed.
    If a termination or truncation is encountered during these steps, the wrapper will stop prematurely.
    The `info` will additionally contain a field `"num_action_repetitions"`, which specifies
    how many steps were actually taken.

    Example:
        >>> import gymnasium as gym
        >>> env = gym.make("CartPole-v1")
        >>> wrapped = RepeatAction(env, num_repeats=2)
        >>> env.reset(seed=123)
        (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {})
        >>> env.step(0)
        (array([ 0.01734283, -0.23932791, -0.02859527, 0.25216764], dtype=float32), 1.0, False, False, {})
        >>> env.step(0) # Perform the same action again
        (array([ 0.01255627, -0.43403012, -0.02355192, 0.5356957 ], dtype=float32), 1.0, False, False, {})
        >>> wrapped.reset(seed=123) # Now we do the same thing with the `RepeatAction` wrapper
        (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {})
        >>> wrapped.step(0)
        (array([ 0.01255627, -0.43403012, -0.02355192, 0.5356957 ], dtype=float32), 2.0, False, False, {'num_action_repetitions': 2})
    """

    def __init__(self, env: gym.Env[ObsType, ActType], num_repeats: int):
        """Initialize RepeatAction wrapper.

        Args:
            env (Env): the wrapped environment
            num_repeats (int): the maximum number of times to repeat the action

        Raises:
            TypeError: if ``num_repeats`` is not an integer
            ValueError: if ``num_repeats`` is not greater than 1
        """
        # Function-scope import keeps this check self-contained.
        from numbers import Integral

        # Reject non-integers up front instead of failing later inside
        # ``range()`` in ``step``. ``Integral`` is used rather than ``int``
        # so that numpy integer scalars are accepted as well.
        if not isinstance(num_repeats, Integral):
            raise TypeError(
                "Number of action repeats should be an integer, "
                f"but got {num_repeats!r} of type {type(num_repeats).__name__}"
            )
        if num_repeats <= 1:
            raise ValueError(
                f"Number of action repeats should be greater than 1, but got {num_repeats}"
            )

        gym.utils.RecordConstructorArgs.__init__(self, num_repeats=num_repeats)
        gym.Wrapper.__init__(self, env)
        # Normalise to a builtin ``int`` so the loop bound in ``step`` does not
        # depend on the concrete integer type the caller passed in.
        self._num_repeats = int(num_repeats)

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Repeat `action` several times.

        This step method will execute `action` at most `num_repeats`-many times in `self.env`,
        or until a termination or truncation is encountered. The reward returned
        is the sum of rewards collected from `self.env`. The last observation from the
        environment is returned.

        Args:
            action: the action to repeat in the wrapped environment

        Returns:
            The observation, termination flag, truncation flag and info of the last
            underlying step taken, together with the summed reward. The info dict
            returned by that last step gets the extra key ``"num_action_repetitions"``.
        """
        total_reward = 0
        # ``num_steps`` doubles as the loop variable: it holds the number of
        # underlying steps taken so far, including when the loop exits early.
        # The constructor guarantees ``_num_repeats > 1``, so the loop body
        # always runs at least once and every name below is bound.
        for num_steps in range(1, self._num_repeats + 1):
            observation, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break
        info["num_action_repetitions"] = num_steps
        return observation, total_reward, terminated, truncated, info
Loading