test.py
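"""Roll out a trained PPO agent in the Quixo environment and render each game."""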
import random

from env_random import QuixoEnv as QuixoEnvRandom
from env_previous import QuixoEnv as QuixoEnvPrevious
from main import RandomPlayer
from opponent import Opponent
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
def make_env():
    # choice = random.choice([0, 1])
    # if choice:
    return QuixoEnvPrevious(
        opponent=Opponent(".//old_results//quixo_ppo_random_opponent_longest"))
    # else:
    #     return QuixoEnvRandom(opponent=RandomPlayer())
# Five vectorized copies of the env; only needed for the batch evaluation below.
vec_env = make_vec_env(make_env, n_envs=5)
# model = PPO.load("quixo_ppo_random_opponent_longest")
# mean_reward, std_reward = evaluate_policy(
#     model, vec_env, n_eval_episodes=10, deterministic=True)
# print(mean_reward, std_reward)
env = make_env()
model = PPO.load(".//old_results//quixo_ppo_random_opponent_longest")

# Roll out one rendered episode (increase the range to watch more games).
for i in range(1):
    obs, _ = env.reset()
    done = False
    while not done:
        action, _ = model.predict(obs)
        # Gymnasium's step() returns (obs, reward, terminated, truncated, info);
        # treat either flag as the end of the episode.
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        env.render()
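
# A minimal follow-up sketch using the vectorized envs built above: batch
# evaluation with SB3's evaluate_policy, mirroring the commented-out block
# (assumes the same checkpoint should be scored across all five envs).
mean_reward, std_reward = evaluate_policy(
    model, vec_env, n_eval_episodes=10, deterministic=True)
print(f"mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")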