-
Notifications
You must be signed in to change notification settings - Fork 6
/
AI_server.py
279 lines (246 loc) · 10.6 KB
/
AI_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
#!/usr/bin/python
# -*- coding: utf-8 -*-
# author: <[email protected]>
import json
import traceback
from datetime import timedelta
import tensorflow as tf
from flask import Flask, request, jsonify, make_response, request, current_app
from functools import update_wrapper
import Renju as renju
def crossdomain(origin=None, methods=None, headers=None,
                max_age=21600, attach_to_all=True,
                automatic_options=True):
    """Build a decorator that adds CORS headers to a Flask view's responses.

    :param origin: allowed origin as a string, or an iterable of origin
        strings (joined with ", "). ``None`` leaves the
        ``Access-Control-Allow-Origin`` header unset.
    :param methods: iterable of allowed HTTP method names; ``None`` means the
        ``Allow`` header of the application's default OPTIONS response is used.
    :param headers: allowed request headers, as a string or an iterable of
        strings (upper-cased and joined with ", ").
    :param max_age: value for ``Access-Control-Max-Age``, in seconds or as a
        ``timedelta``.
    :param attach_to_all: when false, only OPTIONS responses get the headers.
    :param automatic_options: when true, OPTIONS requests are answered with the
        application's default OPTIONS response instead of calling the view.
    :return: a decorator to apply to a view function.
    """
    # ``basestring`` only exists on Python 2; fall back to ``str`` on Python 3.
    try:
        string_types = basestring  # noqa: F821
    except NameError:
        string_types = str

    if methods is not None:
        methods = ', '.join(sorted(x.upper() for x in methods))
    if headers is not None and not isinstance(headers, string_types):
        headers = ', '.join(x.upper() for x in headers)
    # With the default ``origin=None`` there is nothing to join.
    if origin is not None and not isinstance(origin, string_types):
        origin = ', '.join(origin)
    if isinstance(max_age, timedelta):
        # The header carries whole seconds, so drop the fractional part.
        max_age = int(max_age.total_seconds())

    def get_methods():
        """Return the explicit method list, or the app's default ``Allow`` value."""
        if methods is not None:
            return methods
        options_resp = current_app.make_default_options_response()
        return options_resp.headers['allow']

    def decorator(f):
        def wrapped_function(*args, **kwargs):
            if automatic_options and request.method == 'OPTIONS':
                resp = current_app.make_default_options_response()
            else:
                resp = make_response(f(*args, **kwargs))
            if not attach_to_all and request.method != 'OPTIONS':
                return resp

            h = resp.headers
            if origin is not None:
                h['Access-Control-Allow-Origin'] = origin
            h['Access-Control-Allow-Methods'] = get_methods()
            h['Access-Control-Max-Age'] = str(max_age)
            if headers is not None:
                h['Access-Control-Allow-Headers'] = headers
            return resp

        # OPTIONS is handled by wrapped_function, so disable Flask's own
        # automatic OPTIONS handling for this view.
        f.provide_automatic_options = False
        return update_wrapper(wrapped_function, f)

    return decorator
def get_parameter(key, default=None, dtype=None):
    """Read a query-string parameter from the current Flask request.

    :param key: name of the query parameter.
    :param default: value used when the parameter is absent; also returned
        when the conversion with ``dtype`` fails.
    :param dtype: optional callable applied to the value (a substituted
        default is converted too).
    :return: the (possibly converted) value, or ``default``.
    """
    val = request.args.get(key)
    if val is None:
        val = default
    if dtype is not None:
        try:
            val = dtype(val)
        except Exception:
            # Any conversion failure falls back to the default; unlike a bare
            # ``except`` this lets KeyboardInterrupt/SystemExit propagate.
            return default
    return val
def response(**msg):
    """Serialize the given keyword arguments as a JSON response via ``jsonify``."""
    payload = msg
    return jsonify(payload)
# ref: documentation of Flask, [ http://flask.pocoo.org/docs/0.10/quickstart/ ]
app = Flask(__name__)

# Process-wide state shared by the request handlers. Everything starts as
# None; the ``__main__`` section assigns ``args``, ``model_dir_dict`` and
# ``model`` (and ``thread_pool`` when running the MCTS service).
args = model = thread_pool = model_dir_dict = None
@app.route('/')
def index():
    """Root endpoint: reply with a fixed JSON object."""
    body = {'index': "index"}
    return jsonify(body)
@app.route('/switch_model', methods=['GET'])
def render_switch_model():
    """
    inner API, not opened
    support params: (model, "model type, choices=[policy_dl, policy_rl, policy_rollout, value_net]")
                    (file, "model file name, missing or empty means most recent checkpoint")
    :return: JSON with status=0 on success, status=2 on any failure
    """
    global model
    try:
        model_type = get_parameter("model")
        model_file = get_parameter("file", default=None)
        model_dir = model_dir_dict[model_type]

        # No file requested (missing or empty): fall back to the most recent
        # checkpoint recorded in the model directory.
        if not model_file:
            checkpoint = tf.train.get_checkpoint_state(model_dir)
            if not (checkpoint and checkpoint.model_checkpoint_path):
                renju.logger.warn("switch model error, not found avaliable model file")
                return response(status=2)
            # keep only the file name; the directory is tracked separately
            model_file = checkpoint.model_checkpoint_path.rsplit("/", 1)[-1]

        if model is None:
            # Nothing loaded yet: load lazily. This must be tested before
            # ``model["type"]`` is read, otherwise the branch is unreachable.
            model = {"model": renju.load_model(args, model_type, model_file),
                     "type": model_type,
                     "dir": model_dir + "/",
                     "file": model_file}
        elif model["type"] != model_type:
            # a running server only swaps weights of the same model type
            renju.logger.warn("switch model error, model type not equal, (%s, %s)" % (model_type, model["type"]))
            return response(status=2)
        elif model_file != model["file"]:
            model["model"].saver.restore(model["model"].session, model["dir"] + model_file)

        renju.logger.info("successful load model file: %s" % model_file)
        return response(status=0)
    except Exception:
        renju.logger.warn("switch model error, detail=%s" % traceback.format_exc())
        return response(status=2)
@app.route('/action', methods=['POST', 'OPTIONS'])
@crossdomain(origin='*', headers="Origin, X-Requested-With, Content-Type, Accept")
def render_action():
    """
    support params: (model, "query string; model type, choices=[policy_dl, policy_rl, policy_rollout, value_net]")
                    (board, "JSON body; board stream")
                    (player, "JSON body; current player, choices=[black, white]")
    :return: JSON with status=0 plus "action" (or "value" for value_net) on success, status=2 on failure
    """
    try:
        model_type = get_parameter("model")
        post_data = json.loads(request.data)
        board_stream = post_data["board"].strip()
        player = post_data["player"].strip()

        # the requested type must match the model this process has loaded
        if model is None or model["type"] != model_type:
            renju.logger.error("model is None or model type not match, please check!")
            return response(status=2)

        board = renju.stream_to_board(board_stream)
        action = renju.Utility.timeit(lambda: renju.action_model(model_type, model["model"], board, player),
                                      desc="action policy dl, player=%s" % player)
        if action is None:
            return response(status=2)
        if model_type != "value_net":
            return response(status=0, type="decision", action=action.tolist())
        return response(status=0, type="decision", value=action[0])
    except Exception:
        # ``except Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt/SystemExit are not swallowed
        renju.logger.warn("action error, detail=%s" % traceback.format_exc())
        return response(status=2)
@app.route('/simulate', methods=['POST', 'OPTIONS'])
@crossdomain(origin='*', headers="Origin, X-Requested-With, Content-Type, Accept")
def render_simulate():
    """
    run a simulation with the currently loaded model
    support params: (board, "JSON body; board stream")
                    (player, "JSON body; current player, choices=[black, white]")
    :return: JSON with status=0 and "reward" on success, status=2 on failure
    """
    try:
        post_data = json.loads(request.data)
        board = renju.stream_to_board(post_data["board"].strip())
        player = post_data["player"].strip()

        # explicit guard instead of relying on a swallowed TypeError from
        # ``model["type"]`` when nothing has been loaded
        if model is None:
            renju.logger.error("model is None or model type not match, please check!")
            return response(status=2)

        reward = renju.simulate(model["type"], model["model"], board, player)
        return response(status=0, reward=reward)
    except Exception:
        renju.logger.warn("simulate error, detail=%s" % traceback.format_exc())
        return response(status=2)
@app.route('/play', methods=['POST', 'OPTIONS'])
@crossdomain(origin='*', headers="Origin, X-Requested-With, Content-Type, Accept")
def render_play():
    """
    main API
    support params: (action, "JSON body; converted with int() and passed to thread_pool.decision")
                    (auth, "JSON body; auth name, checked with thread_pool.check_auth")
    :return: JSON with status=0 and "action" on success, status=1 when the
             auth name is rejected, status=2 on any other failure
    """
    try:
        post_data = json.loads(request.data)
        op_action = int(post_data["action"])
        auth_name = str(post_data["auth"]).strip()
        if not thread_pool.check_auth(auth_name):
            return response(status=1, msg="not avaliable auth")
        action = thread_pool.decision(op_action, auth_name)
        return response(status=0, action=action, type="play")
    except Exception:
        # ``except Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt/SystemExit are not swallowed
        renju.logger.warn("play error, detail=%s" % traceback.format_exc())
        return response(status=2)
@app.route('/operate', methods=['GET'])
@crossdomain(origin='*', headers="Origin, X-Requested-With, Content-Type, Accept")
def render_connect():
    """
    play API for connection
    support params: (handle, choices=[connect, release, undo])
                    (player, "current player color, default black; used by connect")
                    (auth, "auth name; used by release")
    :return: JSON; for connect, "msg" carries the acquired auth name.
             status=1 for a missing/unknown handle or when no auth is
             available, status=2 on unexpected errors
    """
    try:
        # A missing handle falls through to the "unknown request handle"
        # reply instead of raising on ``None.strip()``.
        handle = get_parameter("handle", default="").strip()

        if handle == "connect":
            player = str(get_parameter("player", default="black"))
            auth_name = thread_pool.acquire_thread(player)
            if auth_name is None:
                return response(status=1, msg="no avaliable auth")
            return response(status=0, msg=auth_name)

        if handle == "release":
            thread_pool.free_thread(get_parameter("auth"))
            return response(status=0, msg="release connect")

        if handle == "undo":
            # NOTE: undo is only acknowledged here; no state is changed.
            return response(status=0, msg="undo finish")

        return response(status=1, msg="unknown request handle")
    except Exception:
        renju.logger.warn("game handle error, detail=%s" % traceback.format_exc())
        return response(status=2)
if __name__ == '__main__':
    # Parse the command line and record where each model type keeps its files.
    arg_parser = renju.parser_argument()
    args = arg_parser.parse_args()
    model_dir_dict = dict(
        policy_dl=args.policy_dl_models_dir,
        policy_rl=args.policy_rl_models_dir,
        policy_rollout=args.policy_rollout_models_dir,
        value_net=args.values_net_models_dir,
    )

    # Each model type listens on its own address; any other model type
    # selects the main address.
    model_type = args.model_type
    port_option_by_type = {
        "policy_dl": "policy_dl_ip_port",
        "policy_rl": "policy_rl_ip_port",
        "policy_rollout": "policy_rollout_ip_port",
        "value_net": "value_net_ip_port",
    }
    ip_port = getattr(args, port_option_by_type.get(model_type, "main_ip_port"))
    host, port = ip_port.split(":")

    if ip_port == args.main_ip_port:
        # Main address: build the MCTS module on top of the model RPC client
        # and create the thread pool used by the play/operate endpoints.
        rpc = renju.ModelRPC(args)
        model = renju.Utility.timeit(
            lambda: renju.MCTS(rpc, visit_threshold=args.mcts_visit_threshold, virtual_loss=args.mcts_virtual_loss,
                               explore_rate=args.mcts_explore_rate, mix_lambda=args.mcts_mix_lambda),
            desc="load MCTS module")
        thread_pool = renju.MCTSThreadPool(model, play_jobs=5, simulate_jobs=3)
    else:
        # Model-specific address: load that single model for this process.
        model = {"model": renju.load_model(args, model_type), "type": model_type,
                 "dir": model_dir_dict[model_type] + "/", "file": None}

    app.run(host=host, port=int(port), debug=False)