Added streaming endpoint support for streaming client (#2139)
Co-authored-by: Dariusz Trawinski <[email protected]>
2 people authored and dkalinowski committed Nov 16, 2023
1 parent e8125c5 commit 9a30c8d
Showing 3 changed files with 62 additions and 47 deletions.
87 changes: 60 additions & 27 deletions demos/common/stream_client/stream_client.py
@@ -70,16 +70,26 @@ def write(self, frame):
    def release(self):
        self.cv_sink.release()

class ImshowOutputBackend(OutputBackend):
    def init(self, sink, fps, width, height):
        ...
    def write(self, frame):
        cv2.imshow("OVMS StreamClient", frame)
        cv2.waitKey(1)
    def release(self):
        cv2.destroyAllWindows()

class StreamClient:
    class OutputBackends():
        ffmpeg = FfmpegOutputBackend()
        cv2 = CvOutputBackend()
        imshow = ImshowOutputBackend()
        none = OutputBackend()
    class Datatypes():
        fp32 = FP32()
        uint8 = UINT8()

    def __init__(self, *, preprocess_callback = None, postprocess_callback, source, sink : str, ffmpeg_output_width = None, ffmpeg_output_height = None, output_backend :OutputBackend = OutputBackends.ffmpeg, verbose : bool = False, exact : bool = True, benchmark : bool = False):
    def __init__(self, *, preprocess_callback = None, postprocess_callback, source, sink: str, ffmpeg_output_width = None, ffmpeg_output_height = None, output_backend: OutputBackend = OutputBackends.ffmpeg, verbose: bool = False, exact: bool = True, benchmark: bool = False, max_inflight_packets: int = 4):
"""
Parameters
----------
@@ -114,6 +124,7 @@ def __init__(self, *, preprocess_callback = None, postprocess_callback, source,
        self.benchmark = benchmark

        self.pq = queue.PriorityQueue()
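        # Bounded queue provides backpressure: put() blocks once
        # max_inflight_packets requests are awaiting server responses.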
        self.req_q = queue.Queue(max_inflight_packets)

    def grab_frame(self):
        success, frame = self.cap.read()
@@ -132,18 +143,24 @@ def grab_frame(self):
    dropped_frames = 0
    frames = 0
    def callback(self, frame, i, timestamp, result, error):
        if error is not None:
            if self.benchmark:
                self.dropped_frames += 1
            if self.verbose:
                print(error)
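        # With the streaming API the callback is registered with no frame id or
        # timestamp (both None); recover them from the OVMS_MP_TIMESTAMP
        # response parameter.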
        if i == None:
            i = result.get_response().parameters["OVMS_MP_TIMESTAMP"].int64_param
        if timestamp == None:
            timestamp = result.get_response().parameters["OVMS_MP_TIMESTAMP"].int64_param
        frame = self.postprocess_callback(frame, result)
        self.pq.put((i, frame, timestamp))
        if error is not None and self.verbose == True:
            print(error)
        self.req_q.get()

    def display(self):
        i = 0
        while True:
            if self.pq.empty():
                continue
            entry = self.pq.get()
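            # With the streaming API, entries are keyed by server timestamp
            # rather than a local frame counter, so accept any entry newer than
            # the last one shown instead of requiring an exact index match.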
            if (entry[0] == i and self.exact) or (entry[0] > i and self.exact is not True):
            if (entry[0] == i and self.exact and self.streaming_api is not True) or (entry[0] > i and (self.exact is not True or self.streaming_api is True)):
                if isinstance(entry[1], str) and entry[1] == "EOS":
                    break
                frame = entry[1]
@@ -161,8 +178,10 @@ def display(self):
            elif self.exact:
                self.pq.put(entry)

    def get_timestamp(self) -> int:
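        # Monotonic timestamp in microseconds (note the * 1e6), used as the
        # OVMS_MP_TIMESTAMP request parameter for the MediaPipe graph.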
        return int(cv2.getTickCount() / cv2.getTickFrequency() * 1e6)

    def start(self, *, ovms_address : str, input_name : str, model_name : str, datatype : Datatype = FP32(), batch = True, limit_stream_duration : int = 0, limit_frames : int = 0):
    def start(self, *, ovms_address : str, input_name : str, model_name : str, datatype : Datatype = FP32(), batch = True, limit_stream_duration : int = 0, limit_frames : int = 0, streaming_api: bool = False):
"""
Parameters
----------
@@ -180,12 +199,15 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat
            Limits how long client could run
        limit_frames : int
            Limits how many frames should be processed
        streaming_api : bool
            Use experimental streaming endpoint
        """

        self.cap = cv2.VideoCapture(self.source, cv2.CAP_ANY)
        self.cap = cv2.VideoCapture(int(self.source) if len(self.source) == 1 and self.source[0].isdigit() else self.source, cv2.CAP_ANY)
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 0)
        fps = self.cap.get(cv2.CAP_PROP_FPS)
        triton_client = grpcclient.InferenceServerClient(url=ovms_address, verbose=False)
        self.streaming_api = streaming_api

        display_th = threading.Thread(target=self.display)
        display_th.start()
@@ -199,26 +221,37 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat
        if self.height is None:
            self.height = np_test_frame.shape[0]
        self.output_backend.init(self.sink, fps, self.width, self.height)

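        # A single gRPC stream is opened up front; every frame is then sent
        # with async_stream_infer and matched to a response in the shared callback.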
        if streaming_api:
            triton_client.start_stream(partial(self.callback, None, None, None))

        i = 0
        frame_number = 0
        total_time_start = time.time()
        while not self.force_exit:
            timestamp = time.time()
            frame = self.grab_frame()
            if frame is not None:
                np_frame = np.array([frame], dtype=datatype.dtype()) if batch else np.array(frame, dtype=datatype.dtype())
                inputs=[grpcclient.InferInput(input_name, np_frame.shape, datatype.string())]
                inputs[0].set_data_from_numpy(np_frame)
                triton_client.async_infer(
                    model_name=model_name,
                    callback=partial(self.callback, frame, i, timestamp),
                    inputs=inputs)
                i += 1
            if limit_stream_duration > 0 and time.time() - total_time_start > limit_stream_duration:
                break
            if limit_frames > 0 and i > limit_frames:
                break
        self.pq.put((i, "EOS"))
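        # try/finally guarantees the EOS marker reaches the display thread and
        # the gRPC stream is closed even if sending fails mid-run.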
        try:
            while not self.force_exit:
                self.req_q.put(frame_number)
                timestamp = time.time()
                frame = self.grab_frame()
                if frame is not None:
                    np_frame = np.array([frame], dtype=datatype.dtype()) if batch else np.array(frame, dtype=datatype.dtype())
                    inputs=[grpcclient.InferInput(input_name, np_frame.shape, datatype.string())]
                    inputs[0].set_data_from_numpy(np_frame)
                    if streaming_api:
                        triton_client.async_stream_infer(model_name=model_name, inputs=inputs, parameters={"OVMS_MP_TIMESTAMP":self.get_timestamp()}, request_id=str(frame_number))
                    else:
                        triton_client.async_infer(
                            model_name=model_name,
                            callback=partial(self.callback, frame, frame_number, timestamp),
                            inputs=inputs)
                    frame_number += 1
                if limit_stream_duration > 0 and time.time() - total_time_start > limit_stream_duration:
                    break
                if limit_frames > 0 and frame_number > limit_frames:
                    break
        finally:
            self.pq.put((frame_number, "EOS"))
            if streaming_api:
                triton_client.stop_stream()
        sent_all_frames = time.time() - total_time_start


@@ -227,4 +260,4 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat
        self.output_backend.release()
        total_time = time.time() - total_time_start
        if self.benchmark:
            print(f"{{\"inference_time\": {sum(self.inference_time)/i}, \"dropped_frames\": {self.dropped_frames}, \"frames\": {self.frames}, \"fps\": {self.frames/total_time}, \"total_time\": {total_time}, \"sent_all_frames\": {sent_all_frames}}}")
            print(f"{{\"inference_time\": {sum(self.inference_time)/frame_number}, \"dropped_frames\": {self.dropped_frames}, \"frames\": {self.frames}, \"fps\": {self.frames/total_time}, \"total_time\": {total_time}, \"sent_all_frames\": {sent_all_frames}}}")
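
Taken together, the new path can be exercised as follows. This is a minimal sketch, not part of the commit: it assumes an OVMS instance exposing a MediaPipe graph named `holisticTracking` on `localhost:9000`, and the input/output tensor names, addresses, and import path are illustrative.

```python
# Minimal sketch of the new streaming endpoint usage; names are illustrative.
from stream_client import StreamClient  # assumed import from demos/common/stream_client

def preprocess(frame):
    # No-op; resize or normalize here if the graph expects it.
    return frame

def postprocess(frame, result):
    # With streaming_api=True the callback receives no original frame,
    # so the displayed image must be taken from the inference result.
    return result.as_numpy("output")  # "output" is an assumed output name

client = StreamClient(
    preprocess_callback=preprocess,
    postprocess_callback=postprocess,
    source="0",                      # a single digit now selects a local camera
    sink="",                         # unused by the imshow backend
    output_backend=StreamClient.OutputBackends.imshow,
    exact=False,
    max_inflight_packets=4,          # bounds the new in-flight request queue
)
client.start(
    ovms_address="localhost:9000",   # assumed server address
    input_name="input",              # assumed graph input name
    model_name="holisticTracking",   # assumed graph/servable name
    datatype=StreamClient.Datatypes.uint8,
    batch=False,
    streaming_api=True,              # opt in to the new streaming endpoint
)
```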
20 changes: 1 addition & 19 deletions demos/mediapipe/holistic_tracking/README.md
@@ -4,8 +4,7 @@ This guide shows how to implement [MediaPipe](../../../docs/mediapipe.md) graph

Example usage of a graph that accepts Mediapipe::ImageFrame as input:

The demo is based on the [upstream Mediapipe holistic demo](https://github.com/google/mediapipe/blob/master/docs/solutions/holistic.md)
and [Mediapipe Iris demo](https://github.com/google/mediapipe/blob/master/docs/solutions/iris.md)
The demo is based on the [upstream Mediapipe holistic demo](https://github.com/google/mediapipe/blob/master/docs/solutions/holistic.md).

## Prepare the server deployment

@@ -82,23 +81,6 @@ Results saved to :image_0.jpg
## Output image
![output](output_image.jpg)

## Run client application for iris tracking
The iris image analysis can be run in a similar way:

```bash
python mediapipe_holistic_tracking.py --graph_name irisTracking --images_list input_images.txt --grpc_port 9000
Running demo application.
Start processing:
Graph name: irisTracking
(640, 960, 3)
Iteration 0; Processing time: 77.03 ms; speed 12.98 fps
Results saved to :image_0.jpg
```

## Output image
![output](output_image1.jpg)
## RTSP Client
A Mediapipe graph can be used for remote analysis of individual images, but the client can also use it to process a complete video stream, as in the example below.
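
For example, the RTSP client changed in this commit could be launched against a local OVMS instance like this (flag names are inferred from the script's `args.*` usage and may differ from the actual argparse definitions):

```bash
python rtsp_client.py --grpc_address localhost:9000 \
    --input_stream 'rtsp://localhost:8554/channel1' \
    --output_stream 'rtsp://localhost:8554/channel2'
```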
2 changes: 1 addition & 1 deletion demos/mediapipe/holistic_tracking/rtsp_client.py
@@ -51,5 +51,5 @@ def postprocess(frame, result):
exact = True

client = StreamClient(postprocess_callback = postprocess, preprocess_callback=preprocess, output_backend=backend, source=args.input_stream, sink=args.output_stream, exact=exact, benchmark=args.benchmark, verbose=args.verbose)
client.start(ovms_address=args.grpc_address, input_name=args.input_name, model_name=args.model_name, datatype = StreamClient.Datatypes.uint8, batch = False, limit_stream_duration = args.limit_stream_duration, limit_frames = args.limit_frames)
client.start(ovms_address=args.grpc_address, input_name=args.input_name, model_name=args.model_name, datatype = StreamClient.Datatypes.uint8, batch = False, limit_stream_duration = args.limit_stream_duration, limit_frames = args.limit_frames, streaming_api=True)
