Skip to content

Commit

Permalink
fix: return equal number of error payloads (#272)
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy authored Feb 13, 2023
1 parent da135e5 commit c4ded3e
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mosec"
version = "0.4.8"
version = "0.4.9"
authors = ["Keming <[email protected]>", "Zichen <[email protected]>"]
edition = "2021"
license = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion examples/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Inference(Worker):

def forward(self, data: List[float]) -> List[float]:
logger.info("sleeping for %s seconds", sum(data))
time.sleep(sum(data))
time.sleep(max(data))
return data


Expand Down
19 changes: 10 additions & 9 deletions mosec/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from multiprocessing.synchronize import Event
from typing import Any, Callable, Optional, Sequence, Tuple, Type

from .errors import DecodingError, ValidationError
from .errors import DecodingError, EncodingError, ValidationError
from .ipc import IPCWrapper
from .protocol import Protocol
from .worker import Worker
Expand Down Expand Up @@ -246,36 +246,37 @@ def coordinate(self):
break

# pylint: disable=broad-except
length = len(payloads)
try:
data = [decoder(item) for item in payloads]
data = (
self.worker.forward(data)
if self.worker.max_batch_size > 1
else (self.worker.forward(data[0]),)
)
if len(data) != len(payloads):
if len(data) != length:
raise ValueError(
"returned data size doesn't match the input data size:"
f"input({len(data)})!=output({len(payloads)})"
f"input({length})!=output({len(data)})"
)
status = self.protocol.FLAG_OK
payloads = [encoder(item) for item in data]
except DecodingError as err:
except (EncodingError, DecodingError) as err:
err_msg = str(err).replace("\n", " - ")
err_msg = err_msg if err_msg else "cannot deserialize request bytes"
logger.info("%s decoding error: %s", self.name, err_msg)
err_msg = err_msg if err_msg else "cannot se/deserialize request bytes"
logger.info("%s encoding/decoding error: %s", self.name, err_msg)
status = self.protocol.FLAG_BAD_REQUEST
payloads = (f"decoding error: {err_msg}".encode(),)
payloads = [f"encoding/decoding error: {err_msg}".encode()] * length
except ValidationError as err:
err_msg = str(err)
err_msg = err_msg if err_msg else "invalid data format"
logger.info("%s validation error: %s", self.name, err_msg)
status = self.protocol.FLAG_VALIDATION_ERROR
payloads = (f"validation error: {err_msg}".encode(),)
payloads = [f"validation error: {err_msg}".encode()] * length
except Exception:
logger.warning(traceback.format_exc().replace("\n", " "))
status = self.protocol.FLAG_INTERNAL_ERROR
payloads = ("inference internal error".encode(),)
payloads = ["inference internal error".encode()] * length

try:
protocol_send(status, ids, payloads)
Expand Down

0 comments on commit c4ded3e

Please sign in to comment.