fix(zjow): fix bug in cliffwalking env (#759)
zjowowen authored Dec 15, 2023
1 parent 7342585 commit 1e6f351
Showing 2 changed files with 8 additions and 3 deletions.
dizoo/cliffwalking/entry/cliffwalking_dqn_deploy.py (2 changes: 1 addition & 1 deletion)
@@ -13,7 +13,7 @@
 def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str):
     main_config.exp_name = f'cliffwalking_dqn_seed0_deploy'
     cfg = compile_config(main_config, create_cfg=create_config, auto=True)
-    env = CliffWalkingEnv(cfg.env.spec)
+    env = CliffWalkingEnv(cfg.env)
     env.enable_save_replay(replay_path=f'./{main_config.exp_name}/video')
     model = DQN(**cfg.policy.model)
     state_dict = torch.load(ckpt_path, map_location='cpu')
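
Note: after this fix the deploy entry passes the whole cfg.env block instead of cfg.env.spec, because the environment constructor (second file below) now reads render_mode and max_episode_steps from the config it is given. A minimal, hypothetical sketch of the call site follows; the field values and the partial config are assumptions for illustration, not the repository's default config.

    # Hypothetical sketch, not repository code: only the config fields visible in
    # this diff are shown; the real default config may carry additional fields.
    from easydict import EasyDict
    from dizoo.cliffwalking.envs.cliffwalking_env import CliffWalkingEnv

    env_cfg = EasyDict(dict(
        render_mode='rgb_array',     # assumed render mode, forwarded to gym.make
        max_episode_steps=300,       # assumed episode cap, forwarded to gym.make
    ))
    env = CliffWalkingEnv(env_cfg)   # the full env config, no longer cfg.env.spec
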
dizoo/cliffwalking/envs/cliffwalking_env.py (9 changes: 7 additions & 2 deletions)
@@ -23,6 +23,9 @@ def __init__(self, cfg: dict) -> None:
         self._init_flag = False
         self._replay_path = None
         self._observation_space = gym.spaces.Box(low=0, high=1, shape=(48, ), dtype=np.float32)
+        self._env = gym.make(
+            "CliffWalking", render_mode=self._cfg.render_mode, max_episode_steps=self._cfg.max_episode_steps
+        )
         self._action_space = self._env.action_space
         self._reward_space = gym.spaces.Box(
             low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
@@ -64,8 +67,10 @@ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
         np.random.seed(seed)

     def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
-        if isinstance(action, np.ndarray) and action.shape == (1, ):
-            action = action.squeeze()  # 0-dim array
+        if isinstance(action, np.ndarray):
+            if action.shape == (1, ):
+                action = action.squeeze()  # 0-dim array
+            action = action.item()
         obs, reward, done, info = self._env.step(action)
         obs_encode = self._encode_obs(obs)
         self._eval_episode_return += reward
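
The behavioural change in step(): previously a (1,)-shaped np.ndarray action was only squeezed to a 0-dim array; now any np.ndarray action is additionally converted to a plain Python int via .item() before being handed to the wrapped gym environment. A standalone sketch of that normalization, written here only for illustration (normalize_action is not a function in the repository):

    # Illustration of the action handling added above; not repository code.
    import numpy as np

    def normalize_action(action):
        if isinstance(action, np.ndarray):
            if action.shape == (1, ):
                action = action.squeeze()  # (1,) array -> 0-dim array
            action = action.item()         # 0-dim array / numpy scalar -> plain Python int
        return action

    assert normalize_action(np.array([2])) == 2   # batched discrete action
    assert normalize_action(np.array(3)) == 3     # 0-dim array action
    assert normalize_action(1) == 1               # plain int passes through unchanged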
