-
Notifications
You must be signed in to change notification settings - Fork 0
/
Humanoid_PPO1_Pybullet_Training.py
87 lines (78 loc) · 3.38 KB
/
Humanoid_PPO1_Pybullet_Training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gym
import pybulletgym
from stable_baselines.common.callbacks import EvalCallback
from stable_baselines.common.policies import MlpPolicy
from stable_baselines import PPO1
import wandb
import json
import os
from pathlib import Path
from datetime import date, time, datetime
from mpi4py import MPI
from stable_baselines import logger
from stable_baselines.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines_utils import *
from stable_baselines.bench import Monitor
def train(env_name, num_time_steps, eval_ep, eval_freq, ckpt_freq, load_model=None):
    """Train a PPO1 agent on `env_name` under MPI, with eval + checkpoint callbacks.

    All run artifacts (monitor csv, eval logs, checkpoints, config json, final
    model) go into a single run directory ``./run/<env>_PPO1_<timestamp>``,
    created relative to the current working directory.

    Parameters
    ----------
    env_name : str
        Gym environment id (e.g. ``'HumanoidPyBulletEnv-v0'``).
    num_time_steps : int or float
        Total environment steps to train for (cast to int for ``learn``).
    eval_ep : int
        Episodes run per evaluation round.
    eval_freq : int
        Steps between evaluation rounds.
    ckpt_freq : int
        Steps between checkpoint saves.
    load_model : str, optional
        Path to a saved PPO1 model to resume from; a fresh model when None.
    """
    env = gym.make(env_name)        # training environment
    eval_env = gym.make(env_name)   # separate environment for evaluation

    rank = MPI.COMM_WORLD.Get_rank()

    # Unique run name, e.g. HumanoidPyBulletEnv-v0_PPO1_2021_05_01_14_33_22.
    # (Original fused date and time with no separator; '%Y_%m_%d_%H_%M_%S'
    # keeps the same characters but inserts the missing '_'.)
    stamp = datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    model_name = f'{env_name}_PPO1_{stamp}'

    # One run directory for everything. The original mixed a CWD-relative
    # mkdir with a __file__-relative `path`, so writes could target a
    # directory that was never created; anchor all paths at run_dir instead.
    run_dir = Path('./run') / model_name
    run_dir.mkdir(parents=True, exist_ok=True)
    path = str(run_dir)

    if rank == 0:
        # Only the root rank records episode stats (<path>.monitor.csv).
        env = Monitor(env, filename=path)

    ############################
    #        callback          #
    ############################
    callback = CallbackList([
        EvalCallback_wandb(eval_env, n_eval_episodes=eval_ep,
                           eval_freq=eval_freq, log_path=path),
        CheckpointCallback(save_freq=ckpt_freq,
                           save_path=str(run_dir / 'ckpt'), name_prefix=''),
    ])

    if load_model:
        model = PPO1.load(env=env, load_path=load_model)
    else:
        model = PPO1(MlpPolicy, env, verbose=1, gamma=0.995, clip_param=0.2,
                     entcoeff=1.0, lam=0.95, optim_epochs=20,
                     optim_batchsize=32768, timesteps_per_actorbatch=320000)

    ############################
    #         Logging          #
    ############################
    if rank == 0:
        logger.configure()
        # Persist the run hyper-parameters next to the artifacts.
        config = {
            'load': [{'load_model': load_model}],
            'eval': [{'eval_freq': eval_freq, 'eval_ep': eval_ep}],
            'ckpt': [{'ckpt_freq': ckpt_freq}],
        }
        with open(run_dir / f'{model_name}.txt', 'w+') as outfile:
            json.dump(config, outfile, indent=4)
    else:
        logger.configure(format_strs=[])  # silence non-root ranks

    ############################
    #           run            #
    ############################
    model.learn(total_timesteps=int(num_time_steps), callback=callback)
    model.save(path + '/' + model_name)
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Train PPO1 on a PyBullet environment.')
    parser.add_argument('--env', type=str, default='HumanoidPyBulletEnv-v0')
    parser.add_argument('--load_model', type=str, default=None)
    parser.add_argument('--nsteps', type=float, default=4e8)
    parser.add_argument('--eval_freq', type=int, default=20000)
    parser.add_argument('--eval_ep', type=int, default=20)
    parser.add_argument('--ckpt_freq', type=int, default=5000)
    # `type=dict` would raise ValueError on any CLI value (dict('...') is
    # invalid); parse a JSON object instead, keeping the same default.
    parser.add_argument('--policy', type=json.loads,
                        default={'net_arch': [128, 64]})
    args = parser.parse_args()

    # Forward --load_model (previously accepted but silently dropped).
    # With the default of None this is identical to the old behavior.
    train(env_name=args.env, num_time_steps=args.nsteps,
          eval_ep=args.eval_ep, eval_freq=args.eval_freq,
          ckpt_freq=args.ckpt_freq, load_model=args.load_model)