-
Notifications
You must be signed in to change notification settings - Fork 0
/
environment.py
154 lines (123 loc) · 4.71 KB
/
environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import time
from collections import deque
import numpy as np
from PIL import ImageGrab
from util.button_presser import press_key, release_key
from util.dimensions import get_screen_bbox
from util.preprocessor import Preprocessor
class Environment:
    """Gym-like environment driven by screen capture and simulated key presses.

    Each step grabs the screen region returned by ``get_screen_bbox()``,
    optionally runs it through a ``Preprocessor``, pushes the result into a
    fixed-length frame buffer and returns the buffered frames combined along
    axis 2 as the observation.

    Unless the environment is read-only, ``W`` and ``SHIFT`` are held down
    while ``A``/``D`` are pressed or released according to the chosen action.

    Attributes:
        frame_times: Seconds spent in ``get_frame`` for each step since the
            last ``reset``.
        obs_times: Seconds spent in ``get_observation`` for each step since
            the last ``reset``.
    """

    def __init__(self,
                 crop_style=0,
                 gray=False,
                 frame_time=1/8,
                 read_only=False,
                 normalize=True,
                 buffer_size=3,
                 field='champions',
                 reward_func=lambda x: 1,):
        """Configure capture, preprocessing, pacing and the frame buffer.

        Args:
            crop_style: Forwarded to ``Preprocessor``. ``None`` disables
                preprocessing, in which case raw screen grabs are buffered.
            gray: Forwarded to ``Preprocessor``; also selects whether
                buffered frames are stacked (``True``) or concatenated
                (``False``) along axis 2.
            frame_time: Minimum number of seconds between two steps.
            read_only: Default for whether key presses are suppressed; can be
                overridden per call in ``reset`` and ``step``.
            normalize: Forwarded to ``Preprocessor`` as ``should_normalize``.
            buffer_size: Number of most recent frames kept for observations.
            field: Forwarded to ``Preprocessor``.
            reward_func: Callable invoked with the newest frame (as returned
                by ``get_frame``); its result is returned as the reward.
        """
        self.field = field
        self.crop_style = crop_style
        if crop_style is None:
            # No preprocessor means there is nothing to ask for observation
            # dims, so they are left unset instead of dereferencing None.
            self.prep = None
            self.ball_obs_dims = None
            self.car_obs_dims = None
        else:
            self.prep = Preprocessor(crop_style, gray=gray, should_normalize=normalize, field=field)
            # Gray frames contribute one channel per buffered frame, colour
            # frames contribute three.
            channels = buffer_size if gray else 3 * buffer_size
            self.ball_obs_dims = self.prep.ball_dims + (channels, )
            self.car_obs_dims = self.prep.car_dims + (channels, )

        self.bboxes = (get_screen_bbox(), )
        self.next_frame = time.time()
        self.read_only = read_only
        self.frame_time = frame_time
        self.get_reward = reward_func
        self.gray = gray

        # Holds the `buffer_size` most recent frames; older ones fall off.
        self.frame_buffer = deque(maxlen=buffer_size)
        self.frame_times = []
        self.obs_times = []

    def _controls_enabled(self, read_only):
        """Return whether keys may be pressed for this call.

        The per-call ``read_only`` argument overrides the instance setting:

            self.read_only  read_only  keys pressed
            True            False      yes
            True            True       no
            False           True       no
            False           False      yes

        With ``read_only=None`` the instance setting decides.
        """
        return read_only is False or (self.read_only is False and read_only is not True)

    def reset(self, read_only=None):
        """Clear timing stats, refill the frame buffer and return an observation.

        Args:
            read_only: Per-call override of ``self.read_only`` (``None`` keeps
                the instance setting).

        Returns:
            The observation produced by the first ``step``.
        """
        self.frame_times = []
        self.obs_times = []
        self.fill_buffer()
        self.next_frame = time.time()

        if self._controls_enabled(read_only):
            release_key('A')
            release_key('D')
            press_key('W')
            press_key('SHIFT')

        # Use step instead of get_observation so the reward function is
        # invoked once on reset as well (initializes the reward model).
        return self.step(read_only=read_only)[0]

    def end(self):
        """Release every key this environment may be holding down."""
        release_key('A')
        release_key('D')
        release_key('W')
        release_key('SHIFT')

    def fill_buffer(self):
        """Fill the whole frame buffer with copies of one freshly grabbed frame."""
        imgs = self.get_frame()
        # Fill to the buffer's actual capacity; a hard-coded 3 leaves larger
        # buffers under-filled and breaks get_observation.
        for _ in range(self.frame_buffer.maxlen):
            self.frame_buffer.append(imgs)

    def get_frame(self):
        """Grab the screen and return it raw or preprocessed.

        Returns:
            The screenshot as a numpy array when ``crop_style`` is ``None``,
            otherwise whatever ``Preprocessor.process_frame`` returns for it
            (indexed as ``[0]`` = ball, ``[1]`` = car by ``get_observation``).
        """
        img = np.array(ImageGrab.grab(bbox=get_screen_bbox()))
        if self.crop_style is None:
            return img
        return self.prep.process_frame(img)

    def get_observation(self):
        """Combine the buffered frames along axis 2 into one observation.

        Returns:
            A single array when ``crop_style`` is ``None``; otherwise a
            ``(balls, cars)`` tuple of arrays.

        Raises:
            IndexError: If the buffer holds fewer frames than its capacity.
        """
        frames = self.frame_buffer
        if len(frames) < frames.maxlen:
            raise IndexError(
                'frame buffer holds %d of %d frames; call reset() or fill_buffer() first'
                % (len(frames), frames.maxlen))

        if self.crop_style is None:
            return np.concatenate(list(frames), axis=2)

        # Gray frames are stacked (adds axis 2); colour frames already have
        # an axis 2 and are concatenated along it.
        combine = np.stack if self.gray else np.concatenate
        balls = combine([frame[0] for frame in frames], axis=2)
        cars = combine([frame[1] for frame in frames], axis=2)
        return balls, cars

    def step(self, action=0, read_only=None):
        """Wait for the next frame slot, apply ``action`` and observe.

        Args:
            action: ``0`` releases both steering keys, ``1`` holds ``A`` and
                releases ``D``, ``2`` holds ``D`` and releases ``A``. Any
                other value leaves the steering keys untouched.
            read_only: Per-call override of ``self.read_only`` (``None`` keeps
                the instance setting).

        Returns:
            ``(obs, reward, done, info)`` where ``done`` is always ``False``
            and ``info`` is the time in seconds that was left until the
            scheduled frame (negative when the step started late).
        """
        # Read the clock once: re-reading it between the check and the sleep
        # can yield a negative duration, which time.sleep rejects.
        sleep_time = self.next_frame - time.time()
        if sleep_time > 0:
            time.sleep(sleep_time)
        self.next_frame = time.time() + self.frame_time

        if self._controls_enabled(read_only):
            press_key('W')
            press_key('SHIFT')
            if action == 0:
                release_key('A')
                release_key('D')
            elif action == 1:
                press_key('A')
                release_key('D')
            elif action == 2:
                press_key('D')
                release_key('A')

        frame_start = time.time()
        imgs = self.get_frame()
        self.frame_times.append(time.time() - frame_start)

        r = self.get_reward(imgs)
        self.frame_buffer.append(imgs)

        obs_start = time.time()
        obs = self.get_observation()
        self.obs_times.append(time.time() - obs_start)

        # done and info only exist so the env behaves like a gym env; they
        # are not used by this module.
        done = False
        info = sleep_time
        return obs, r, done, info
def test():
    """Run 300 no-op steps and print average timings.

    Prints, in order, the mean seconds per frame grab, the mean seconds per
    observation build and the mean seconds per full ``step`` call.
    """
    # Initial delay before any keys are pressed (presumably to let the
    # operator focus the game window -- confirm).
    time.sleep(2)
    times = []
    env = Environment()
    try:
        env.reset()
        for _ in range(15 * 20):
            start = time.time()
            env.step()
            times.append(time.time() - start)
    finally:
        # Always release the keys, even if a step fails; otherwise they
        # would stay held down after the script dies.
        env.end()

    print(sum(env.frame_times) / len(env.frame_times))
    print(sum(env.obs_times) / len(env.obs_times))
    print(sum(times)/len(times))
# Run the timing benchmark only when this file is executed as a script.
if __name__ == "__main__":
    test()