Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correct file response for internal service calling #5071

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions src/_bentoml_impl/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

if t.TYPE_CHECKING:
from httpx._types import RequestFiles
from PIL import Image

from _bentoml_sdk import Service
from bentoml._internal.external_typing import ASGIApp
Expand Down Expand Up @@ -358,10 +359,11 @@ def is_file_field(k: str) -> bool:

def _deserialize_output(self, payload: Payload, endpoint: ClientEndpoint) -> t.Any:
data = iter(payload.data)
if (ot := endpoint.output.get("type")) == "string":
return bytes(next(data)).decode("utf-8")
elif ot == "bytes":
return bytes(next(data))
if (endpoint.output.get("type")) == "string":
content = bytes(next(data))
if endpoint.output.get("format") == "binary":
return content
return content.decode("utf-8")
elif endpoint.output_spec is not None:
model = self.serde.deserialize_model(payload, endpoint.output_spec)
if isinstance(model, RootModel):
Expand Down Expand Up @@ -470,10 +472,8 @@ def _call(
raise map_exception(resp)
if endpoint.stream_output:
return self._parse_stream_response(endpoint, resp)
elif (
endpoint.output.get("type") == "file"
and self.media_type == "application/json"
):
elif endpoint.output.get("type") == "file":
# file responses are always raw binaries whatever the serde is
return self._parse_file_response(endpoint, resp)
else:
return self._parse_response(endpoint, resp)
Expand All @@ -497,11 +497,18 @@ def _parse_stream_response(

def _parse_file_response(
self, endpoint: ClientEndpoint, resp: httpx.Response
) -> pathlib.Path:
) -> pathlib.Path | Image.Image:
from multipart.multipart import parse_options_header
from PIL import Image

content_disposition = resp.headers.get("content-disposition")
content_type = resp.headers.get("content-type", "")
filename: str | None = None
if endpoint.output.get("pil"):
image_formats = (
[content_type[6:]] if content_type.startswith("image/") else None
)
return Image.open(io.BytesIO(resp.read()), formats=image_formats)
if content_disposition:
_, options = parse_options_header(content_disposition)
if b"filename" in options:
Expand Down Expand Up @@ -589,10 +596,8 @@ async def _call(
raise map_exception(resp)
if endpoint.stream_output:
return self._parse_stream_response(endpoint, resp)
elif (
endpoint.output.get("type") == "file"
and self.media_type == "application/json"
):
elif endpoint.output.get("type") == "file":
# file responses are always raw binaries whatever the serde is
return await self._parse_file_response(endpoint, resp)
else:
return await self._parse_response(endpoint, resp)
Expand All @@ -618,11 +623,18 @@ async def _parse_stream_response(

async def _parse_file_response(
self, endpoint: ClientEndpoint, resp: httpx.Response
) -> pathlib.Path:
) -> pathlib.Path | Image.Image:
from multipart.multipart import parse_options_header
from PIL import Image

content_disposition = resp.headers.get("content-disposition")
content_type = resp.headers.get("content-type", "")
filename: str | None = None
if endpoint.output.get("pil"):
image_formats = (
[content_type[6:]] if content_type.startswith("image/") else None
)
return Image.open(io.BytesIO(await resp.aread()), formats=image_formats)
if content_disposition:
_, options = parse_options_header(content_disposition)
if b"filename" in options:
Expand Down
4 changes: 1 addition & 3 deletions src/_bentoml_sdk/io_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,6 @@ async def to_http_response(cls, obj: t.Any, serde: Serde) -> Response:
from starlette.responses import Response
from starlette.responses import StreamingResponse

from _bentoml_impl.serde import JSONSerde

if inspect.isasyncgen(obj):
try:
# try if there is any error before the first yield
Expand Down Expand Up @@ -248,7 +246,7 @@ def content_stream() -> t.Generator[str | bytes, None, None]:
headers=payload.headers,
)
else:
if is_file_type(type(obj)) and isinstance(serde, JSONSerde):
if is_file_type(type(obj)):
if isinstance(obj, pathlib.PurePath):
media_type = mimetypes.guess_type(obj)[0] or cls.mime_type()
should_inline = media_type.startswith("image")
Expand Down
2 changes: 1 addition & 1 deletion src/_bentoml_sdk/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __get_pydantic_json_schema__(
) -> dict[str, t.Any]:
value = handler(schema)
if handler.mode == "validation":
value.update({"type": "file", "format": "image"})
value.update({"type": "file", "format": "image", "pil": True})
else:
value.update({"type": "string", "format": "binary"})
return value
Expand Down
Loading