"""gym_test.py -- train a NEAT population on Gym's ``CartPole-v1``.

Run as a script.  Depending on the module-level switches it either trains a
``NeatLearner`` population (saving checkpoints and a final summary under
``SAVE_DIR``) or loads a previously pickled learner from ``BEST_DIR`` and
plays a single episode with it.
"""
import os
import pickle
import shutil
import sys
import time
from os.path import join, exists

import gym
import matplotlib.pyplot as plt

from NEAT.neatLearner import NeatLearner
from NEAT.utils import print_hyperparameters
# --- Output locations -------------------------------------------------------
# Directory that receives checkpoints, the summary and (optionally) videos.
SAVE_DIR = 'results/new_pole'
# Directory whose 'final.pkl' is loaded when PLAY_BEST is enabled.
BEST_DIR = 'results/new_pole3'
# Pickled learner backup written by the training loop.
pkl_file = join(SAVE_DIR, 'backup.pkl')

# --- Run-mode switches ------------------------------------------------------
# Call env.render() on every step of every episode.
RENDER = False
# Load a saved learner and play one episode instead of training.
PLAY_BEST = False
# Wrap the environment in gym.wrappers.Monitor, recording to SAVE_DIR/video.
RECORD = False

# --- Evaluation criteria ----------------------------------------------------
# Number of episodes check_solution() averages over.
SOLUTION_ITERATIONS = 100
# A genome is accepted as a solution when its average reward exceeds this.
SOLUTION_FOUND = 475.0

# --- Population / network inputs --------------------------------------------
# Population size handed to NeatLearner and iterated every generation.
NUM_GENOMES = 1000
# When True, get_input() passes only observation[0] and observation[2].
NO_VELOCITY = True
def check_solution(genome):
    """Re-evaluate one genome over SOLUTION_ITERATIONS full episodes.

    Uses the module-level ``env`` and ``nl`` objects created by the script
    entry point, so it can only be called once those exist.

    Args:
        genome: Index of the genome, passed through to ``nl.get_output``.

    Returns:
        A ``(solved, average)`` tuple: ``average`` is the total reward
        divided by SOLUTION_ITERATIONS and ``solved`` is True when that
        average is strictly greater than SOLUTION_FOUND.
    """
    print('Checking solution')
    total_reward = 0.0
    for _ in range(SOLUTION_ITERATIONS):
        obs = env.reset()
        finished = False
        while not finished:
            if RENDER:
                env.render()
            activation = nl.get_output(get_input(obs), genome)
            # Threshold the single network output into a binary action.
            if activation < 0.5:
                move = 0
            else:
                move = 1
            obs, step_reward, finished, _info = env.step(move)
            total_reward += step_reward

    mean_reward = total_reward / SOLUTION_ITERATIONS
    print('Avg fitness ', str(mean_reward))
    solved = mean_reward > SOLUTION_FOUND
    return solved, mean_reward
def save_and_exit(generations, neatLearner, best_fitness_hist, seconds):
    """Write the final training artefacts into SAVE_DIR and terminate.

    Three files are produced: ``summary.txt`` (run statistics followed by
    the output of ``print_hyperparameters()``), ``final.pkl`` (the pickled
    learner) and ``fitness.pkl`` (the pickled fitness history).

    Args:
        generations: Generation count reported in the summary.
        neatLearner: Learner object to pickle into ``final.pkl``.
        best_fitness_hist: Fitness history to pickle into ``fitness.pkl``.
        seconds: Elapsed training time in seconds, reported in the summary.

    Raises:
        SystemExit: Always, once all files have been written.
    """
    # The start-up prompt may have removed SAVE_DIR with rmtree; make sure
    # the directory exists before opening files inside it.
    os.makedirs(SAVE_DIR, exist_ok=True)

    with open(join(SAVE_DIR, 'summary.txt'), 'w') as f:
        f.write('Trained in {} generations\n'.format(generations))
        f.write('Trained in {:.4f} seconds\n'.format(seconds))
        f.write('Num Genomes: {}\n'.format(NUM_GENOMES))
        f.write('Include Velocity: {}\n'.format(not NO_VELOCITY))
        # NOTE(review): relies on print_hyperparameters() returning a string
        # (f.write rejects None) -- confirm against NEAT.utils.
        f.write(print_hyperparameters())

    with open(join(SAVE_DIR, 'final.pkl'), 'wb') as f:
        pickle.dump(neatLearner, f)

    with open(join(SAVE_DIR, 'fitness.pkl'), 'wb') as f:
        pickle.dump(best_fitness_hist, f)

    # sys.exit() instead of the site-provided exit(): the latter is meant for
    # the interactive shell and is absent when the site module is not loaded.
    sys.exit()
def get_input(observation):
    """Select the part of an environment observation fed to the network.

    Args:
        observation: Indexable observation returned by the environment.

    Returns:
        ``observation`` itself when NO_VELOCITY is false; otherwise a new
        two-element list holding ``observation[0]`` and ``observation[2]``.
    """
    if not NO_VELOCITY:
        return observation
    # Keep entries 0 and 2 only; for CartPole these are presumably the cart
    # position and pole angle (see the Gym docs), dropping the velocities.
    first, third = observation[0], observation[2]
    return [first, third]
def play_best(n, g):
    """Play a single episode with the learner's best genome, then exit.

    Args:
        n: Learner object; its ``get_best_output(inputs)`` is queried on
            every step.
        g: Environment providing ``reset()`` and a ``step(action)`` that
            returns a 4-tuple ``(observation, reward, done, info)``.

    Raises:
        SystemExit: Always, after the episode finishes.
    """
    observation = g.reset()
    done = False
    while not done:
        out = n.get_best_output(get_input(observation))
        # Threshold the single network output into a binary action.
        action = 0 if out < 0.5 else 1
        observation, _reward, done, _info = g.step(action)
    print('Done playing best')
    # sys.exit() instead of the site-provided exit(), which is intended for
    # interactive use only.
    sys.exit()
if __name__ == '__main__':
    # NOTE: `env` and `nl` are deliberately module-level names here;
    # check_solution() reads them as globals, so this block must not be
    # moved into a function without passing them explicitly.
    start_time = time.time()

    # Never silently overwrite an earlier run: ask before wiping SAVE_DIR.
    if exists(SAVE_DIR):
        choice = input(('Save directory {} exists, do you want '
                        'to remove it and continue? [y/n]\n').format(SAVE_DIR))
        if choice == 'y':
            shutil.rmtree(SAVE_DIR)
        else:
            print('Exiting')
            # sys.exit() rather than the interactive-only exit() helper.
            sys.exit()

    env = gym.make('CartPole-v1')
    if RECORD:
        env = gym.wrappers.Monitor(env, join(SAVE_DIR, 'video'))

    if PLAY_BEST:
        # Replay mode: load a previously saved learner; play_best() exits.
        # NOTE(review): pickle.load executes arbitrary code from the file --
        # only point BEST_DIR at trusted results.
        with open(join(BEST_DIR, 'final.pkl'), 'rb') as f:
            nl = pickle.load(f)
        play_best(nl, env)
    elif NO_VELOCITY:
        # Two inputs, matching the two values get_input() returns.
        nl = NeatLearner(2, 1, NUM_GENOMES, True)
    else:
        nl = NeatLearner(4, 1, NUM_GENOMES, False)

    best_fitness = 0.0
    generation = 0
    best_fitness_hist = []

    # Train until a genome passes check_solution(); save_and_exit() ends
    # the process from inside the loop.
    while True:
        nl.start_generation()
        print('Starting Generation {}'.format(nl.generations))

        for genome in range(NUM_GENOMES):
            # One episode per genome; fitness is the summed reward.
            fitness = 0.0
            observation = env.reset()
            done = False
            while not done:
                if RENDER:
                    env.render()
                out = nl.get_output(get_input(observation), genome)
                action = 0 if out < 0.5 else 1
                observation, reward, done, info = env.step(action)
                fitness += reward

            # NOTE(review): indentation reconstructed -- this check is
            # assumed to run once per genome, after its episode has ended.
            if fitness > 499:
                # Promising genome: validate it over many episodes and
                # reward it with twice its average score.
                winner, avg = check_solution(genome)
                fitness += avg * 2
                if winner:
                    print('Found a winner!')
                    nl.best_genome = nl.genomes[genome]
                    nl.assign_fitness(fitness, genome)
                    nl.save_top_genome(SAVE_DIR, 'best')
                    best_fitness_hist.append(fitness)
                    save_and_exit(generation,
                                  nl,
                                  best_fitness_hist,
                                  time.time() - start_time)

            nl.assign_fitness(fitness, genome)

        generation += 1
        nl.end_generation()
        print('Num species: {}'.format(len(nl.species)))

        if nl.best_genome.fitness > best_fitness:
            best_fitness = nl.best_genome.fitness
            print('New Best:', best_fitness)
            nl.save_top_genome(SAVE_DIR, 'best')
        # NOTE(review): assumed to record one history entry per generation
        # (not only on improvement) -- confirm.
        best_fitness_hist.append(best_fitness)

        # Periodic snapshot of the top genome and the species exemplars.
        if generation % 30 == 0:
            nl.save_top_genome(SAVE_DIR, 'gen{}'.format(generation))
            nl.save_species_exemplars(join(SAVE_DIR, 'gen_'+str(generation)))

        # NOTE(review): assumed to back up the learner every generation
        # rather than only on the 30-generation snapshot -- confirm.
        with open(pkl_file, 'wb') as f:
            pickle.dump(nl, f)