-
Notifications
You must be signed in to change notification settings - Fork 237
/
Copy pathtest.py
133 lines (107 loc) · 3.04 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import pdb
from data import data
from data.data import F, Z
from btc_env import BitcoinEnv, Mode
from hypersearch import HSearchEnv
import pandas as pd
# Size of the synthetic table: the fake row counter below reports it and the
# fake dataframe loader generates exactly this many rows.
COUNT = 101


def count_rows(*args, **kwargs):
    """Return the fixed synthetic row count, ignoring every argument.

    Accepts arbitrary positional and keyword arguments so it can be assigned
    in place of ``data.count_rows`` regardless of how that is called.
    """
    return COUNT
def db_to_dataframe_wrapper(direction=1):
    """Create a fake dataframe loader that yields a monotonic series.

    Args:
        direction: ``1`` produces values ``0, 1, ..., COUNT - 1`` (rising);
            any other value produces ``COUNT, COUNT - 1, ..., 1`` (falling).

    Returns:
        A function accepting (and ignoring) any arguments that returns a
        ``pandas.DataFrame`` with ``COUNT`` rows and the columns ``a_o``,
        ``a_h``, ``a_l``, ``a_c`` and ``a_v``, all holding the same value
        within a row.
    """
    columns = ('a_o', 'a_h', 'a_l', 'a_c', 'a_v')

    def db_to_dataframe(*args, **kwargs):
        if direction == 1:
            values = range(COUNT)
        else:
            values = (COUNT - step for step in range(COUNT))
        # Every column of a row carries the same number.
        rows = [{column: value for column in columns} for value in values]
        return pd.DataFrame(rows)

    return db_to_dataframe
def reset(env):
    """Restore `env` to a fresh training episode with 1000 cash and value.

    Sets ``start_cash`` and ``start_value`` to 1000, selects the
    ``Mode.TRAIN`` dataset and then calls ``env.reset()``.
    """
    initial_funds = 1000
    env.start_cash = initial_funds
    env.start_value = initial_funds
    env.use_dataset(Mode.TRAIN)
    env.reset()
def _final_advantage(env, action, steps):
    """Run one fresh episode repeating `action` and return its last advantage.

    Args:
        env: Environment exposing ``execute``, ``episode_finished`` and
            ``acc.episode.advantages``.
        action: Value passed to ``env.execute`` on every step.
        steps: Number of ``env.execute`` calls to make.

    Returns:
        The last entry of ``env.acc.episode.advantages`` once the episode
        has been closed with ``env.episode_finished(None)``.
    """
    reset(env)
    for _ in range(steps):
        env.execute(action)
    env.episode_finished(None)
    return env.acc.episode.advantages[-1]


def _check_market(env, rising, label):
    """Assert the sign of the final advantage for each action in one market.

    For a long episode (90 steps) and then a single-step episode, checks in
    order: action ``0`` gives an advantage of exactly 0; action ``1`` gives a
    positive advantage when `rising` and a negative one otherwise; action
    ``-1`` gives the opposite sign of action ``1``.

    Args:
        env: Environment under test.
        rising: True when the synthetic series increases, False when it
            decreases.
        label: Market name used only in assertion messages.

    Raises:
        AssertionError: If any advantage has an unexpected sign.
    """
    # 90 was annotated "step_window - start_timestep" in the original script;
    # presumably chosen to stay inside the 101 synthetic rows -- TODO confirm.
    for steps in (90, 1):
        held = _final_advantage(env, 0, steps)
        assert held == 0, (
            '%s market, action 0 x%d: expected 0, got %r' % (label, steps, held))

        positive = _final_advantage(env, 1, steps)
        assert (positive > 0) if rising else (positive < 0), (
            '%s market, action 1 x%d: wrong sign, got %r' % (label, steps, positive))

        negative = _final_advantage(env, -1, steps)
        assert (negative < 0) if rising else (negative > 0), (
            '%s market, action -1 x%d: wrong sign, got %r' % (label, steps, negative))


def main():
    """Sanity-check advantage signs of ``BitcoinEnv`` on synthetic data.

    Replaces the ``data`` module's table list, target column, row counter and
    dataframe loader with in-memory fakes, builds a ``BitcoinEnv`` from the
    hyperparameters returned by ``HSearchEnv.get_winner()``, and then asserts
    the sign of the final advantage for actions 0, 1 and -1, first on a
    rising series and then on a falling one.

    Raises:
        AssertionError: If any scenario yields an advantage of the wrong sign.
    """
    hs = HSearchEnv(net_type='conv2d')
    # Only the flat hyperparameter mapping is needed here.
    flat, _, _ = hs.get_winner()
    flat['unimodal'] = True
    flat['arbitrage'] = False
    flat['indicators'] = False
    flat['step_window'] = 10

    # Point the data module at a single fake table 'a' whose columns match
    # the 'a_*' names produced by db_to_dataframe_wrapper.
    data.tables = [
        dict(
            name='a',
            ts='ts',
            cols=dict(o=F, h=F, l=F, c=F, v=Z)
        )
    ]
    data.target = 'a_c'
    data.count_rows = count_rows

    # Rising series.
    data.db_to_dataframe = db_to_dataframe_wrapper(1)
    env = BitcoinEnv(flat, name='ppo_agent')
    _check_market(env, rising=True, label='bull')

    # Falling series. NOTE(review): the loader is swapped after `env` was
    # constructed; this relies on the env re-reading data when reset -- confirm.
    data.db_to_dataframe = db_to_dataframe_wrapper(-1)
    _check_market(env, rising=False, label='bear')
# Run the checks only when executed as a script, not when imported.
if __name__ == '__main__':
    main()