Skip to content

Commit

Permalink
polish(pu): pistonball reuse PTZRecordVideo
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Oct 22, 2024
1 parent d1e427e commit e916841
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 54 deletions.
53 changes: 6 additions & 47 deletions dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,16 @@
from typing import Any, List, Union, Optional, Dict
import gymnasium as gym
import numpy as np
from functools import reduce
from typing import List, Optional, Dict

from ding.envs import BaseEnv, BaseEnvTimestep, FrameStackWrapper
from ding.torch_utils import to_ndarray, to_list
import gymnasium as gym
import numpy as np
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.envs.common.common_function import affine_transform
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY
from pettingzoo.utils.conversions import parallel_wrapper_fn
from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PTZRecordVideo
from pettingzoo.butterfly import pistonball_v6


# Custom wrapper for recording videos in PettingZoo environments
class PTZRecordVideo(gym.wrappers.RecordVideo):
def step(self, action):
"""
Custom step function for handling PettingZoo environments
with gymnasium's RecordVideo wrapper.
"""
observations, rewards, terminateds, truncateds, infos = self.env.step(action)

# Check if any agent has terminated or truncated
if not (self.terminated is True or self.truncated is True):
self.step_id += 1
if not self.is_vector_env:
if terminateds or truncateds:
self.episode_id += 1
self.terminated = terminateds
self.truncated = truncateds
elif terminateds[0] or truncateds[0]:
self.episode_id += 1
self.terminated = terminateds[0]
self.truncated = truncateds[0]

# Capture the video frame if recording
if self.recording:
assert self.video_recorder is not None
self.video_recorder.capture_frame()
self.recorded_frames += 1
if self.video_length > 0 and self.recorded_frames > self.video_length:
self.close_video_recorder()
elif not self.is_vector_env:
if terminateds is True or truncateds is True:
self.close_video_recorder()
elif terminateds[0] or truncateds[0]:
self.close_video_recorder()

elif self._video_enabled():
self.start_video_recorder()

return observations, rewards, terminateds, truncateds, infos


@ENV_REGISTRY.register('petting_zoo_pistonball')
class PettingZooPistonballEnv(BaseEnv):
"""
Expand Down
10 changes: 3 additions & 7 deletions dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,12 @@
from pettingzoo.mpe.simple_spread.simple_spread import Scenario


# Custom wrapper for recording videos in PettingZoo environments
class PTZRecordVideo(gym.wrappers.RecordVideo):
def step(self, action):
"""Steps through the environment using action, recording observations if :attr:`self.recording`."""
# gymnasium==0.27.1
(
observations,
rewards,
terminateds,
truncateds,
infos,
) = self.env.step(action)
observations, rewards, terminateds, truncateds, infos = self.env.step(action)

# Because pettingzoo returns a dict of terminated and truncated, we need to check if any of the values are True
if not (self.terminated is True or self.truncated is True): # the first location for modifications
Expand All @@ -39,6 +34,7 @@ def step(self, action):
self.terminated = terminateds[0]
self.truncated = truncateds[0]

# Capture the video frame if recording
if self.recording:
assert self.video_recorder is not None
self.video_recorder.capture_frame()
Expand Down

0 comments on commit e916841

Please sign in to comment.