Skip to content

Commit b3839e2

Browse files
committed
add stream to stream rpc support
1 parent 5ef0ad3 commit b3839e2

File tree

4 files changed

+93
-59
lines changed

4 files changed

+93
-59
lines changed

README.md

+20-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
> 深度学习模型在落地时需要提供高效快速交互接口,业务逻辑和深度模型解码通常运行在不同类型的机器上。
1717
Http 并不适合大量数据的交互,而RPC (Remote Procedure Call) 远程过程调用, 而RPC在TCP层实现。提高了开发效率,算法工程师可以不必花费更多精力放在具体的接口实现上,而是专注于算法优化上。
1818

19-
tzrpc 框架基于google的 [grpc](https://github.com/grpc/) 实现,需要Python 3.7及以上
19+
tzrpc 框架基于google的 [grpc](https://github.com/grpc/) 实现,需要Python 3.7及以上, 支持流式传输!!!
2020

2121
目前支持以下基础类型:
2222

@@ -83,6 +83,11 @@ def send_bool(_bool: bool):
8383
def send_python_obj(data):
8484
return data
8585

86+
@server.register(stream=True)
87+
def gumbel(num):
88+
if num % 3 == 0:
89+
yield f"number is {num}, you win"
90+
8691
if __name__ == '__main__':
8792
server.run("localhost", 8000)
8893
```
@@ -128,6 +133,11 @@ def send_bool(_bool: bool):
128133
def send_python_obj(data):
129134
return data
130135

136+
@client.register(stream=True)
137+
def gumbel(num):
138+
for i in range(num):
139+
yield i
140+
131141
if __name__ == '__main__':
132142
print(say_hello("lovemefan"))
133143
print(send_numpy_obj())
@@ -145,6 +155,10 @@ if __name__ == '__main__':
145155

146156
python_obj = testOb("test_name", 20)
147157
print(send_python_obj(python_obj).__dict__)
158+
159+
# 流式demo
160+
for i in gumbel(10):
161+
print(i)
148162
```
149163

150164
### 客户端输出
@@ -168,4 +182,9 @@ False
168182
True
169183
170184
{'name': 'test_name', 'age': 20}
185+
186+
number is 0, you win
187+
number is 3, you win
188+
number is 6, you win
189+
number is 9, you win
171190
```

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
setuptools.setup(
1414
name="tzrpc",
15-
version="v0.0.3",
15+
version="v0.0.4",
1616
author="Lovemefan, Yunnan Key Laboratory of Artificial Intelligence, "
1717
"Kunming University of Science and Technology, Kunming, Yunnan ",
1818
author_email="[email protected]",

tzrpc/client/client.py

+58-42
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import numbers
88
import pickle
99
from functools import partial
10-
from typing import Generator
11-
1210
import grpc
1311
import numpy as np
1412

@@ -50,11 +48,14 @@ def __init__(self, server_address: str, debug=False):
5048
if debug:
5149
logger.setLevel(logging.DEBUG)
5250
self.server_address = server_address
53-
self.channel = grpc.insecure_channel(server_address,
54-
options=[('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
55-
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
56-
('grpc.max_metadata_size', MAX_METADATA_SIZE)]
57-
)
51+
self.channel = grpc.insecure_channel(
52+
server_address,
53+
options=[
54+
("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
55+
("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
56+
("grpc.max_metadata_size", MAX_METADATA_SIZE),
57+
],
58+
)
5859

5960
def register(self, func=None, stream=False):
6061
"""
@@ -63,54 +64,69 @@ def register(self, func=None, stream=False):
6364
stream(bool): use stream mode
6465
:return:
6566
"""
67+
6668
# if return_type not in self.__type:
6769
# raise ValueError(f"TZRPC return type only support {self.__type}")
6870
def decorate(_func):
6971
return partial(wrapper, _func)
7072

7173
def wrapper(_func, *args, **kwargs):
7274
stub = toObjectStub(self.channel, _func.__name__)
75+
# logger.debug(f"{self.channel._channel.name}-{_func.__name__}")
7376
result = _func(*args, **kwargs)
74-
if isinstance(result, str):
75-
request = String(text=result)
76-
response = stub.toString(request).text
77+
if not stream:
78+
if isinstance(result, str):
79+
request = String(text=result)
80+
response = stub.toString(request).text
7781

78-
elif isinstance(result, np.ndarray):
79-
request = numpy2protobuf(result)
80-
response = protobuf2numpy(stub.toNdarray(request))
81-
elif isinstance(result, bytes):
82-
request = Bytes(data=result)
83-
response = stub.toBytes(request).data
84-
elif isinstance(result, bool):
85-
request = Boolean(value=result)
86-
response = bool(stub.toBoolean(request).value)
87-
elif isinstance(result, numbers.Number):
88-
if isinstance(result, int):
89-
request = Integer(value=result)
90-
response = stub.toInteger(request).value
91-
elif isinstance(result, float):
92-
request = Double(value=result)
93-
response = stub.toDouble(request).value
82+
elif isinstance(result, np.ndarray):
83+
request = numpy2protobuf(result)
84+
response = protobuf2numpy(stub.toNdarray(request))
85+
elif isinstance(result, bytes):
86+
request = Bytes(data=result)
87+
response = stub.toBytes(request).data
88+
elif isinstance(result, bool):
89+
request = Boolean(value=result)
90+
response = bool(stub.toBoolean(request).value)
91+
elif isinstance(result, numbers.Number):
92+
if isinstance(result, int):
93+
request = Integer(value=result)
94+
response = stub.toInteger(request).value
95+
elif isinstance(result, float):
96+
request = Double(value=result)
97+
response = stub.toDouble(request).value
98+
else:
99+
logger.exception(f"Type of {type(result)} is not support")
100+
elif self.tensor_type is not None and isinstance(
101+
result, self.tensor_type
102+
):
103+
request = numpy2protobuf(result.numpy())
104+
response = torch.from_numpy(
105+
protobuf2numpy(stub.toNdarray(request)).copy()
106+
)
94107
else:
95-
logger.exception(f"Type of {type(result)} is not support")
96-
elif self.tensor_type is not None and isinstance(result, self.tensor_type):
97-
request = numpy2protobuf(result.numpy())
98-
response = torch.from_numpy(
99-
protobuf2numpy(stub.toNdarray(request)).copy()
100-
)
101-
else:
102-
logger.info(f"Data will serialized by pickle and send with bytes")
103-
obj = pickle.dumps(result)
104-
if stream:
105-
request = Bytes(data=b"STREAM"+obj)
106-
response = stub.toBytesStream(request)
107-
else:
108-
request = Bytes(data=b"PICKLE"+obj)
108+
logger.info("Data will serialized by pickle and send with bytes")
109+
obj = pickle.dumps(result)
110+
request = Bytes(data=b"PICKLE" + obj)
109111
response = pickle.loads(stub.toBytes(request).data)
112+
logger.debug(f"type of data loaded is {type(response)}")
113+
return response
114+
else:
115+
116+
def pack(data):
117+
for i in data:
118+
_obj = pickle.dumps(i)
119+
yield Bytes(data=b"PICKLE" + _obj)
110120

111-
logger.debug(f"type of data loaded is {type(response)}")
121+
def unpack(data):
122+
for i in data:
123+
_data = i.data
124+
is_pickled = True if b"PICKLE" in _data[:6] else False
125+
if is_pickled:
126+
_data = pickle.loads(_data[6:])
127+
yield _data
112128

113-
return response
129+
return unpack(stub.toBytesStream(pack(result)))
114130

115131
if func is not None:
116132
return partial(wrapper, func)

tzrpc/decorator/rpc.py

+14-15
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# @File : rpc.py
66
import pickle
77
from functools import partial
8-
98
from tzrpc.proto.py.Boolean_pb2 import Boolean
109
from tzrpc.proto.py.Bytes_pb2 import Bytes
1110
from tzrpc.proto.py.Number_pb2 import Double, Float, Integer
@@ -16,15 +15,14 @@
1615

1716
servicers = []
1817

19-
logger = get_logger(to_std=True, stdout_level="INFO", save_log_file=False)
18+
logger = get_logger(to_std=True, stdout_level="DEBUG", save_log_file=False)
2019

2120

2221
class RpcServicer:
2322
def __init__(self):
2423
pass
2524

2625
def register(self, task=None, stream=False):
27-
2826
if task is not None:
2927
_listener = Listener(task)
3028
# print(_listener)
@@ -119,20 +117,21 @@ def toNdarrays(self, request, context):
119117
response = ndarrays(ndarray=_ndarray)
120118
return response
121119

122-
def toBytesStream(self, request, context):
123-
logger.debug(f"Method toBytes({request}, {context}) called.")
124-
data = request.data
125-
is_pickled = True if b"PICKLE" in data[:6] else False
126-
if is_pickled:
127-
data = pickle.loads(data[6:])
128-
result = self.task(data)
129-
if is_pickled:
130-
result = pickle.dumps(result)
120+
def toBytesStream(self, request, context):
121+
logger.debug(f"Method toBytesStream({request}, {context}) called.")
131122

132-
response = Bytes(data=result)
133-
return response
123+
for data in request:
124+
data = data.data
125+
is_pickled = True if b"PICKLE" in data[:6] else False
126+
if is_pickled:
127+
data = pickle.loads(data[6:])
128+
result = self.task(data)
129+
130+
for i in result:
131+
_result = pickle.dumps(i)
132+
yield Bytes(data=b"PICKLE" + _result)
134133

135-
def toNdarrayStream(self, request, context):
134+
def toNdarrayStream(self, request, context):
136135
pass
137136

138137
def toTensor(self, request, context):

0 commit comments

Comments
 (0)