feat(pandas): add support for serializing pd.DataFrame in Arrow IPC formats #4779

Open · wants to merge 1 commit into base: main

52 changes: 49 additions & 3 deletions src/bentoml/_internal/io_descriptors/pandas.py
@@ -46,6 +46,7 @@
pb_v1alpha1, _ = import_generated_stubs("v1alpha1")
pd = LazyLoader("pd", globals(), "pandas", exc_msg=EXC_MSG)
np = LazyLoader("np", globals(), "numpy")
pyarrow = LazyLoader("pyarrow", globals(), "pyarrow")

logger = logging.getLogger(__name__)

@@ -144,6 +145,8 @@ def _series_openapi_schema(
class SerializationFormat(Enum):
JSON = "application/json"
PARQUET = "application/octet-stream"
ARROW_FILE = "application/vnd.apache.arrow.file"
ARROW_STREAM = "application/vnd.apache.arrow.stream"
CSV = "text/csv"

def __init__(self, mime_type: str):
@@ -156,6 +159,10 @@ def __str__(self) -> str:
return "parquet"
elif self == SerializationFormat.CSV:
return "csv"
elif self == SerializationFormat.ARROW_FILE:
return "arrow_file"
elif self == SerializationFormat.ARROW_STREAM:
return "arrow_stream"
else:
raise ValueError(f"Unknown serialization format: {self}")

@@ -173,6 +180,10 @@ def _infer_serialization_format_from_request(
return SerializationFormat.PARQUET
elif content_type == "text/csv":
return SerializationFormat.CSV
elif content_type == "application/vnd.apache.arrow.file":
return SerializationFormat.ARROW_FILE
elif content_type == "application/vnd.apache.arrow.stream":
return SerializationFormat.ARROW_STREAM
elif content_type:
logger.debug(
"Unknown Content-Type ('%s'), falling back to '%s' serialization format.",
@@ -196,6 +207,13 @@ def _validate_serialization_format(serialization_format: SerializationFormat):
raise MissingDependencyException(
"Parquet serialization is not available. Try installing pyarrow or fastparquet first."
)
if (
serialization_format is SerializationFormat.ARROW_FILE
or serialization_format is SerializationFormat.ARROW_STREAM
) and find_spec("pyarrow") is None:
raise MissingDependencyException(
"Arrow serialization is not available. Try installing pyarrow first."
)


class PandasDataFrame(
@@ -311,6 +329,8 @@ def predict(input_df: pd.DataFrame) -> pd.DataFrame:
- :obj:`json` - JSON text format (inferred from content-type ``"application/json"``)
- :obj:`parquet` - Parquet binary format (inferred from content-type ``"application/octet-stream"``)
- :obj:`csv` - CSV text format (inferred from content-type ``"text/csv"``)
- :obj:`arrow_file` - Arrow file format (inferred from content-type ``"application/vnd.apache.arrow.file"``)
- :obj:`arrow_stream` - Arrow stream format (inferred from content-type ``"application/vnd.apache.arrow.stream"``)

Returns:
:obj:`PandasDataFrame`: IO Descriptor that represents a :code:`pd.DataFrame`.
@@ -325,7 +345,13 @@ def __init__(
enforce_dtype: bool = False,
shape: tuple[int, ...] | None = None,
enforce_shape: bool = False,
default_format: t.Literal["json", "parquet", "csv"] = "json",
default_format: t.Literal[
"json",
"parquet",
"csv",
"arrow_file",
"arrow_stream",
] = "json",
):
self._orient: ext.DataFrameOrient = orient
self._columns = columns
@@ -371,6 +397,8 @@ def _from_sample(self, sample: ext.PdDataFrame) -> ext.PdDataFrame:
- :obj:`json` - JSON text format (inferred from content-type ``"application/json"``)
- :obj:`parquet` - Parquet binary format (inferred from content-type ``"application/octet-stream"``)
- :obj:`csv` - CSV text format (inferred from content-type ``"text/csv"``)
- :obj:`arrow_file` - Arrow file format (inferred from content-type ``"application/vnd.apache.arrow.file"``)
- :obj:`arrow_stream` - Arrow stream format (inferred from content-type ``"application/vnd.apache.arrow.stream"``)

Returns:
:class:`~bentoml._internal.io_descriptors.pandas.PandasDataFrame`: IODescriptor from given users inputs.
@@ -539,6 +567,12 @@ async def from_http_request(self, request: Request) -> ext.PdDataFrame:
res = pd.read_parquet(io.BytesIO(obj), engine=get_parquet_engine())
elif serialization_format is SerializationFormat.CSV:
res: ext.PdDataFrame = pd.read_csv(io.BytesIO(obj), dtype=dtype)
elif serialization_format is SerializationFormat.ARROW_FILE:
with pyarrow.ipc.open_file(obj) as reader:
res = reader.read_pandas()
elif serialization_format is SerializationFormat.ARROW_STREAM:
with pyarrow.ipc.open_stream(obj) as reader:
res = reader.read_pandas()
else:
raise InvalidArgument(
f"Unknown serialization format ({serialization_format})."
@@ -576,6 +610,18 @@ async def to_http_response(
resp = obj.to_parquet(engine=get_parquet_engine())
elif serialization_format is SerializationFormat.CSV:
resp = obj.to_csv()
elif serialization_format is SerializationFormat.ARROW_FILE:
sink = pyarrow.BufferOutputStream()
batch = self.to_arrow(obj)
with pyarrow.ipc.new_file(sink, batch.schema) as writer:
writer.write_batch(batch)
resp = sink.getvalue().to_pybytes()
elif serialization_format is SerializationFormat.ARROW_STREAM:
sink = pyarrow.BufferOutputStream()
batch = self.to_arrow(obj)
with pyarrow.ipc.new_stream(sink, batch.schema) as writer:
writer.write_batch(batch)
resp = sink.getvalue().to_pybytes()
else:
raise InvalidArgument(
f"Unknown serialization format ({serialization_format})."
@@ -743,7 +789,7 @@ def from_arrow(self, batch: pyarrow.RecordBatch) -> ext.PdDataFrame:
def to_arrow(self, df: pd.Series[t.Any]) -> pyarrow.RecordBatch:
import pyarrow

return pyarrow.RecordBatch.from_pandas(df)
return pyarrow.RecordBatch.from_pandas(df, preserve_index=True)

def spark_schema(self) -> pyspark.sql.types.StructType:
from pyspark.pandas.typedef import as_spark_type
@@ -1201,7 +1247,7 @@ def to_arrow(self, series: pd.Series[t.Any]) -> pyarrow.RecordBatch:
import pyarrow

df = series.to_frame()
return pyarrow.RecordBatch.from_pandas(df)
return pyarrow.RecordBatch.from_pandas(df, preserve_index=True)

def spark_schema(self) -> pyspark.sql.types.StructType:
from pyspark.pandas.typedef import as_spark_type
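For context, the new `default_format` options added above could be used when declaring a service API. The following is a minimal sketch, not part of this diff, with hypothetical service and endpoint names:

import bentoml
import pandas as pd

from bentoml.io import PandasDataFrame

# Hypothetical service name, for illustration only.
svc = bentoml.Service("arrow_demo")

@svc.api(
    input=PandasDataFrame(default_format="arrow_stream"),
    output=PandasDataFrame(default_format="arrow_stream"),
)
def predict_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    # Echo the input back; a real service would run inference here.
    return df
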
31 changes: 31 additions & 0 deletions tests/e2e/bento_server_http/tests/test_io.py
@@ -7,6 +7,7 @@
from typing import Tuple

import numpy as np
import pyarrow
import pytest

from bentoml.client import AsyncHTTPClient
@@ -144,6 +145,36 @@ async def test_pandas(host: str):
assert response.status_code == 200
assert await response.aread() == b'[{"col1":202}]'

headers = {
"Content-Type": "application/vnd.apache.arrow.stream",
"Origin": ORIGIN,
}
sink = pyarrow.BufferOutputStream()
batch = pyarrow.RecordBatch.from_pandas(df, preserve_index=True)
with pyarrow.ipc.new_stream(sink, batch.schema) as writer:
writer.write_batch(batch)
data = sink.getvalue().to_pybytes()
response = await client.client.post(
"/predict_dataframe", headers=headers, data=data
)
assert response.status_code == 200
assert await response.aread() == b'[{"col1":202}]'

headers = {
"Content-Type": "application/vnd.apache.arrow.file",
"Origin": ORIGIN,
}
sink = pyarrow.BufferOutputStream()
batch = pyarrow.RecordBatch.from_pandas(df, preserve_index=True)
with pyarrow.ipc.new_file(sink, batch.schema) as writer:
writer.write_batch(batch)
data = sink.getvalue().to_pybytes()
response = await client.client.post(
"/predict_dataframe", headers=headers, data=data
)
assert response.status_code == 200
assert await response.aread() == b'[{"col1":202}]'


@pytest.mark.asyncio
async def test_file(host: str, bin_file: str):
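The e2e tests above exercise Arrow request bodies; on the response side, decoding an Arrow-formatted body can mirror the descriptor's from_http_request logic. A minimal sketch, assuming the response bytes are an Arrow IPC stream (or file) payload; the helper name is hypothetical:

import pyarrow

def dataframe_from_arrow(payload: bytes, stream: bool = True):
    # pyarrow.ipc.open_stream / open_file accept bytes-like input directly;
    # read_pandas() materializes the record batches as a pandas DataFrame.
    opener = pyarrow.ipc.open_stream if stream else pyarrow.ipc.open_file
    with opener(payload) as reader:
        return reader.read_pandas()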