-
Notifications
You must be signed in to change notification settings - Fork 9
/
server.py
70 lines (57 loc) · 1.72 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import asyncio
import pickle

import socketio
import tensorflow as tf
from aiohttp import web

import inference
# Socket.IO server mounted on an aiohttp web application.
sio = socketio.AsyncServer()
app = web.Application()
sio.attach(app)

# Model handles and vocabularies. They start empty here and are assigned by
# the ``__main__`` block before the server starts; the message handler reads
# sess, x1/x2/x3, y and both vocabularies as module globals.
x1 = x2 = x3 = y = None
graph = sess = None
word2idx, idx2word = {}, {}
async def index(request):
    """Serve the client-side application (``index.html``) as an HTML page.

    ``request`` is unused but required by the aiohttp handler signature.

    The file is decoded explicitly as UTF-8 instead of with the platform's
    locale encoding. ``web.Response(text=...)`` sends the body as UTF-8, so
    reading with the same codec keeps non-ASCII content intact on every host.
    Assumes index.html is saved as UTF-8 -- TODO confirm.
    """
    with open('index.html', encoding='utf-8') as f:
        return web.Response(text=f.read(), content_type='text/html')
def connect(sid, environ):
    """Log each newly connected Socket.IO client by its session id.

    ``environ`` is supplied by python-socketio and is not used here.
    """
    print("connect ", sid)


# Function-call form of ``sio.on``; equivalent to decorating ``connect``.
sio.on('connect', connect)
@sio.on('message')
async def message(sid, data):
    """Reply to one chat message from client ``sid`` with the model's answer.

    Only the first line of ``data`` is used. The decoded reply is printed and
    emitted back as a ``'message'`` event to the sender's room only.

    Reads the module globals ``sess``, ``x1``, ``x2``, ``x3``, ``y``,
    ``word2idx`` and ``idx2word``, which the ``__main__`` block assigns before
    the server starts.
    """
    # Keep only the first line; anything after a newline is ignored.
    data = data.split("\n")[0]
    print(data, ", sid: ", sid)

    input_id = inference.seq(data, word2idx)
    feed = {
        x1: [input_id],       # a batch containing this one encoded sentence
        x2: [len(input_id)],  # the length of that sentence
        x3: 1,                # batch size
    }

    # Session.run blocks; running it directly here would freeze the event
    # loop (and every other connected client) for the whole inference. Run it
    # on the loop's default thread pool instead. The feed dict is passed
    # positionally because run_in_executor forwards only positional args
    # (Session.run(fetches, feed_dict, ...)). seq/dec stay on this thread in
    # case they rely on TF's thread-local default session.
    # get_event_loop() rather than get_running_loop(): presumably keeps
    # Python 3.6 (TF 1.x era) working -- confirm the target version.
    loop = asyncio.get_event_loop()
    y_out = await loop.run_in_executor(None, sess.run, y, feed)

    sen = inference.dec(y_out[0], idx2word)
    print("chatbot: " + sen)
    await sio.emit('message', sen, room=sid)
def disconnect(sid):
    """Log when a Socket.IO client disconnects, by its session id."""
    print('disconnect ', sid)


# Function-call form of ``sio.on``; equivalent to decorating ``disconnect``.
sio.on('disconnect', disconnect)
# HTTP routes: files from the ./static directory are served under /static,
# and GET / returns the client page produced by index().
app.router.add_static('/static', 'static')
app.router.add_get('/', index)
if __name__ == '__main__':
    # Load the frozen graph and open a session bound to it.
    graph = inference.load_graph("frozen.pb")
    sess = tf.InteractiveSession(graph=graph)

    # Resolve the model's feed/fetch tensors in a fixed order: encoder inputs,
    # encoder input lengths, batch size, and the decoder's eval predictions.
    x1, x2, x3, y = (
        graph.get_tensor_by_name(tensor_name)
        for tensor_name in (
            'prefix/encoder_inputs:0',
            'prefix/encoder_inputs_length:0',
            'prefix/batch_size:0',
            'prefix/decoder/decoder_pred_eval:0',
        )
    )

    # Vocabulary lookup tables. pickle.load can execute arbitrary code, so
    # these files must come from a trusted source.
    with open('word2idx.pkl', 'rb') as handle:
        word2idx = pickle.load(handle)
    with open('idx2word.pkl', 'rb') as handle:
        idx2word = pickle.load(handle)

    # Smoke-test the vocabulary on a fixed phrase ('測試', "test") before
    # starting to serve.
    input_id = inference.seq('測試', word2idx)
    print('測試: ', input_id)
    print('load pickle: done')

    web.run_app(app)