forked from duckietown/gym-duckietown
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sanity_test.py
88 lines (70 loc) · 2.37 KB
/
sanity_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# For env register
import gym_duckietown
from typing import List
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import CheckpointCallback
from gymnasium.wrappers.resize_observation import ResizeObservation
from gym_duckietown.wrappers import DiscreteWrapper, CropObservation, SegmentMiddleLaneWrapper, SegmentRemoveExtraInfo, SegmentLaneWrapper, TransposeToConv2d
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
import multiprocessing
from .custom_net import CustomCNN
MODEL_PREFIX = "dqn_customnet_stack_sm"
SEED = 2**30 + 8394
THREAD_COUNT = 4
def wrap(env):
env = DiscreteWrapper(env)
env = segment(env)
return env
def segment(env):
env = CropObservation(env, 120)
env = SegmentLaneWrapper(env)
env = SegmentMiddleLaneWrapper(env)
env = SegmentRemoveExtraInfo(env)
env = ResizeObservation(env, (40, 80))
return env
policy_kwargs = dict(
features_extractor_class=CustomCNN,
features_extractor_kwargs=dict(features_dim=3),
)
def train(id):
print(f"Running train on process {id}")
env = make_vec_env("Duckietown-udem1-v0", n_envs=1, wrapper_class=wrap, seed=SEED + 100 *id)
env = VecFrameStack(env, 5)
model = DQN(
policy="CnnPolicy",
batch_size=32,
env=env,
gamma=0.99,
learning_rate=0.00005,
buffer_size=50000,
tensorboard_log="./runs",
learning_starts=10000,
seed=SEED,
policy_kwargs=policy_kwargs,
verbose=2,
optimize_memory_usage=True,
replay_buffer_kwargs={"handle_timeout_termination": False}
)
ckpt_callback = CheckpointCallback(
save_freq=50_000,
save_path="./checkpoints/",
name_prefix=f"{MODEL_PREFIX}_{id}_ckpt",
save_replay_buffer=True,
save_vecnormalize=True,
)
model.learn(
total_timesteps=500_000,
callback=ckpt_callback,
progress_bar=True,
tb_log_name=f"{MODEL_PREFIX}_{id}",
)
model.save(f"{MODEL_PREFIX}_{id}_duck")
print(f"Thread {id} done.")
procs: List[multiprocessing.Process] = []
for i in range(1, THREAD_COUNT+1):
procs.append(multiprocessing.Process(target=train, args=[i]))
for p in procs:
p.start()
for p in procs:
p.join()