-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathapi_socket.py
81 lines (65 loc) · 2.19 KB
/
api_socket.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from fastapi import FastAPI, File
import uvicorn
from utils.inference_socket import Cpu, Gpu
from multiprocessing import Process, Manager
import socket
import selectors
import numpy as np
from utils.serving_args import get_args
# Readiness multiplexer used by gpu_listener: it holds the listening socket
# and every accepted client socket, with the handler to call stored as ``data``.
selector = selectors.DefaultSelector()
# CPU-side pre/post-processing helper used by the REST endpoint. Created at
# module level, so it exists in every process that runs this module's top level.
cpu = Cpu()
# NOTE: do not create a multiprocessing Manager at module level. Nothing in this
# file uses one. Under the "spawn" start method every child process re-imports
# this module, so a module-level Manager() would try to start a new process
# during bootstrapping and fail with RuntimeError.
def recvall(sock):
    """Read from *sock* until the peer closes its sending side; return all bytes read.

    An empty ``recv`` result marks EOF, so this only returns once the other end
    has shut down its write direction (or closed the connection).

    :param sock: a connected socket, or anything with a ``recv(bufsize)`` method.
    :return: everything received, as a ``bytearray`` (empty if nothing was sent).
    """
    buffer = bytearray()
    # Read in 64 KiB chunks until recv() reports EOF with b"".
    while chunk := sock.recv(65536):
        buffer += chunk
    return buffer
# Loopback address: the REST servers and the internal GPU socket both bind
# here, so neither is reachable from other machines.
localhost: str = "127.0.0.1"
def gpu_listener():
    """Serve GPU inference requests on ``localhost:5000``; never returns normally.

    Each connection carries one request. The client sends its whole payload and
    then shuts down its write side. This listener reads until EOF, passes the
    bytes to ``Gpu.process``, writes the result back and closes the connection.

    Reads the module-level ``args`` (assigned in the ``__main__`` block) for the
    model to load, so ``args`` must be defined before this is called.
    """
    gpu = Gpu(args.model)

    def accept_connection(server_socket):
        # Listening socket is readable: accept the client and wait for its
        # request data to become readable before handling it.
        client_socket, _addr = server_socket.accept()
        selector.register(fileobj=client_socket, events=selectors.EVENT_READ, data=gpu_process)

    def gpu_process(client_socket):
        try:
            # recvall blocks until the client half-closes, so requests are
            # effectively handled one at a time.
            batch = recvall(client_socket)
            preds = gpu.process(batch)
            # sendall, not send: send() may transmit only part of a large
            # buffer, and the client would then receive truncated predictions.
            client_socket.sendall(preds)
        finally:
            # Unregister before closing (required by the selectors API). Do both
            # even when processing raises, so the descriptor is not leaked.
            selector.unregister(client_socket)
            client_socket.close()

    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server_socket.bind((localhost, 5000))
    server_socket.listen()
    selector.register(fileobj=server_socket, events=selectors.EVENT_READ, data=accept_connection)

    while True:
        for key, _mask in selector.select():
            # Every registration stores its handler callback in ``data``.
            handler = key.data
            handler(key.fileobj)
def start_rest(port: int):
    """Run one uvicorn server for the prediction app on ``localhost:<port>``.

    Blocks for the lifetime of the server. The ``__main__`` block starts one of
    these per port, each in its own ``Process``.
    """
    uvicorn.run(get_app(), host=localhost, port=port)
def get_app():
    """Build the FastAPI application exposing ``POST /predictions/resnet``.

    The endpoint works in three steps:

    1. Pre-processes the uploaded files with ``cpu.pre_process``.
    2. Sends the resulting array's raw bytes to the GPU listener on
       ``localhost:5000``.
    3. Reads the reply as float32 values in rows of 1000 and returns
       ``cpu.post_process`` of them.
    """
    import asyncio  # local import: only the endpoint below needs it

    app = FastAPI()

    def _infer_on_gpu(payload: bytes) -> np.ndarray:
        """Send *payload* to the GPU listener; return its reply as a float32 (N, 1000) array."""
        # ``with`` guarantees the socket is closed even if connect/send/recv raises.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
            conn.connect((localhost, 5000))
            conn.sendall(payload)
            # Half-close so the listener's recvall sees EOF and starts processing.
            conn.shutdown(socket.SHUT_WR)
            raw = recvall(conn)
        # assumes the GPU side replies with float32 scores, 1000 per row — TODO confirm
        return np.frombuffer(raw, dtype="float32").reshape(-1, 1000)

    @app.post("/predictions/resnet")
    async def predict(data: list[bytes] = File(...)):
        batch = cpu.pre_process(data)
        # The socket round trip blocks. Run it in a worker thread so this
        # process's event loop can keep serving other requests meanwhile.
        scores = await asyncio.to_thread(_infer_on_gpu, batch.tobytes())
        return cpu.post_process(scores)

    return app
if __name__ == "__main__":
    # ``args`` must stay module-global: gpu_listener reads ``args.model``.
    args = get_args()
    # Launch one REST server process per port, counting up from 8080.
    for rest_port in range(8080, 8080 + args.ports):
        Process(target=start_rest, args=(rest_port,)).start()
    # The parent process becomes the GPU listener, which loops forever.
    gpu_listener()