-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_6DOF_sb_integration.py
47 lines (34 loc) · 1.29 KB
/
test_6DOF_sb_integration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
Script to test functionality of the 6DOF environment
"""
from my_environment.envs import Rocket6DOF
from stable_baselines3.ppo import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from gym.wrappers import RecordVideo
from stable_baselines3.common.env_checker import check_env
# Import the initial conditions from the setup file
import yaml
from yaml.loader import SafeLoader
# Load the run configuration. The path is relative, so the script has to be
# launched from the directory that contains config.yaml.
# The encoding is given explicitly so parsing does not depend on the locale.
with open("config.yaml", encoding="utf-8") as f:
    config = yaml.load(f, Loader=SafeLoader)

# Top-level sections of the configuration file.
# NOTE(review): sb3_config is read but not used by the code visible in this
# script — confirm whether it should be forwarded to the PPO constructor.
sb3_config = config["sb3_config"]
env_config = config["env_config"]

# Keyword arguments forwarded to every Rocket6DOF constructor call below.
kwargs = env_config
# Build a throwaway instance only for the gym / sb3 compatibility check
# (the render check is not skipped), then release it.
probe_env = Rocket6DOF(**kwargs)
check_env(probe_env, skip_render_check=False)
probe_env.close()
del probe_env

# Fresh instance used by the rest of the script.
env = Rocket6DOF(**kwargs)
import time

# Test usage with a stable_baselines3 model: a PPO agent built on the
# environment, with no training step performed in this script.
model = PPO('MlpPolicy', env, verbose=1)

# Use a separate environment for evaluation, wrapped so that the evaluation
# episodes are recorded into the '6DOF_videos' folder.
eval_env = RecordVideo(
    env=Rocket6DOF(**kwargs),
    video_folder='6DOF_videos',
    video_length=500,
)

# Untrained agent, before any learning.
# The evaluation runs on eval_env (not env); otherwise the recording wrapper
# is never stepped and no video is produced.
# perf_counter is monotonic, so the measured interval cannot be skewed by
# system clock adjustments.
start_time = time.perf_counter()
try:
    mean_reward, std_reward = evaluate_policy(
        model, eval_env, n_eval_episodes=5, render=True
    )
    finish_time = time.perf_counter()
finally:
    # Always release both environments, even if the evaluation fails;
    # closing the wrapper presumably lets the recorder finalize its files.
    eval_env.close()
    env.close()

print(f"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")
print(f"time to record the episodes: {finish_time-start_time}")