Skip to content

Commit 44a9763

Browse files
authored
feat: add pyarrow list and struct to pandas engine (#1699)
* feat: add pyarrow list and struct to pandas engine Signed-off-by: Ajith Aravind <[email protected]> * test: add tests for pyarrow list and struct Signed-off-by: Ajith Aravind <[email protected]> * fix: linting errors Signed-off-by: Ajith Aravind <[email protected]> --------- Signed-off-by: Ajith Aravind <[email protected]>
1 parent c3011b5 commit 44a9763

File tree

3 files changed

+128
-1
lines changed

3 files changed

+128
-1
lines changed

pandera/engines/pandas_engine.py

+66-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
List,
2323
NamedTuple,
2424
Optional,
25+
Tuple,
2526
Type,
2627
Union,
2728
cast,
@@ -1765,7 +1766,7 @@ class ArrowDecimal128(DataType, dtypes.Decimal):
17651766
precision: int = 28
17661767
scale: int = 0
17671768

1768-
def __post_init__(self) -> None:
1769+
def __post_init__(self):
17691770
type_ = pd.ArrowDtype(
17701771
pyarrow.decimal128(self.precision, self.scale)
17711772
)
@@ -1832,3 +1833,67 @@ def from_parametrized_dtype(
18321833
value_type=pyarrow_dtype.value_type, # type: ignore
18331834
ordered=pyarrow_dtype.ordered, # type: ignore
18341835
)
1836+
1837+
@Engine.register_dtype(
1838+
equivalents=[
1839+
pyarrow.list_,
1840+
pyarrow.ListType,
1841+
pyarrow.FixedSizeListType,
1842+
]
1843+
)
1844+
@immutable(init=True)
1845+
class ArrowList(DataType):
1846+
"""Semantic representation of a :class:`pyarrow.list_`."""
1847+
1848+
type: Optional[pd.ArrowDtype] = dataclasses.field(
1849+
default=None, init=False
1850+
)
1851+
value_type: Optional[Union[pyarrow.DataType, pyarrow.Field]] = (
1852+
pyarrow.string()
1853+
)
1854+
list_size: Optional[int] = -1
1855+
1856+
def __post_init__(self):
1857+
type_ = pd.ArrowDtype(
1858+
pyarrow.list_(self.value_type, self.list_size)
1859+
)
1860+
object.__setattr__(self, "type", type_)
1861+
1862+
@classmethod
1863+
def from_parametrized_dtype(
1864+
cls,
1865+
pyarrow_dtype: Union[pyarrow.ListType, pyarrow.FixedSizeListType],
1866+
):
1867+
try:
1868+
_dtype = cls(
1869+
value_type=pyarrow_dtype.value_type, # type: ignore
1870+
list_size=pyarrow_dtype.list_size, # type: ignore
1871+
)
1872+
except AttributeError:
1873+
_dtype = cls(value_type=pyarrow_dtype.value_type) # type: ignore
1874+
return _dtype
1875+
1876+
@Engine.register_dtype(equivalents=[pyarrow.struct, pyarrow.StructType])
1877+
@immutable(init=True)
1878+
class ArrowStruct(DataType):
1879+
"""Semantic representation of a :class:`pyarrow.struct`."""
1880+
1881+
type: Optional[pd.ArrowDtype] = dataclasses.field(
1882+
default=None, init=False
1883+
)
1884+
fields: Optional[
1885+
Union[
1886+
Iterable[Union[pyarrow.Field, Tuple[str, pyarrow.DataType]]],
1887+
Dict[str, pyarrow.DataType],
1888+
]
1889+
] = tuple()
1890+
1891+
def __post_init__(self):
1892+
type_ = pd.ArrowDtype(pyarrow.struct(self.fields))
1893+
object.__setattr__(self, "type", type_)
1894+
1895+
@classmethod
1896+
def from_parametrized_dtype(cls, pyarrow_dtype: pyarrow.StructType):
1897+
return cls(
1898+
fields=[pyarrow_dtype.field(i) for i in range(pyarrow_dtype.num_fields)] # type: ignore
1899+
)

tests/core/test_pandas_engine.py

+60
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import hypothesis.strategies as st
99
import numpy as np
1010
import pandas as pd
11+
import pyarrow
1112
import pytest
1213
import pytz
1314
from hypothesis import given
@@ -237,3 +238,62 @@ def test_pandas_date_coerce_dtype(to_df, data):
237238
assert (
238239
coerced_data.map(lambda x: isinstance(x, date)) | coerced_data.isna()
239240
).all()
241+
242+
243+
pandas_arrow_dtype_cases = (
244+
(
245+
pd.Series([["a", "b", "c"]]),
246+
pyarrow.list_(pyarrow.string()),
247+
),
248+
(
249+
pd.Series([["a", "b"]]),
250+
pyarrow.list_(pyarrow.string(), 2),
251+
),
252+
(
253+
pd.Series([{"foo": 1, "bar": "a"}]),
254+
pyarrow.struct(
255+
[
256+
("foo", pyarrow.int64()),
257+
("bar", pyarrow.string()),
258+
]
259+
),
260+
),
261+
)
262+
263+
264+
@pytest.mark.parametrize(("data", "dtype"), pandas_arrow_dtype_cases)
265+
def test_pandas_arrow_dtype(data, dtype):
266+
"""Test pyarrow dtype."""
267+
dtype = pandas_engine.Engine.dtype(dtype)
268+
269+
dtype.coerce(data)
270+
271+
272+
pandas_arrow_dtype_errors_cases = (
273+
(
274+
pd.Series([["a", "b", "c"]]),
275+
pyarrow.list_(pyarrow.int64()),
276+
),
277+
(
278+
pd.Series([["a", "b"]]),
279+
pyarrow.list_(pyarrow.string(), 3),
280+
),
281+
(
282+
pd.Series([{"foo": 1, "bar": "a"}]),
283+
pyarrow.struct(
284+
[
285+
("foo", pyarrow.string()),
286+
("bar", pyarrow.int64()),
287+
]
288+
),
289+
),
290+
)
291+
292+
293+
@pytest.mark.parametrize(("data", "dtype"), pandas_arrow_dtype_errors_cases)
294+
def test_pandas_arrow_dtype_errors(data, dtype):
295+
"""Test pyarrow dtype raises ArrowInvalid or ArrowTypeError on bad data."""
296+
dtype = pandas_engine.Engine.dtype(dtype)
297+
298+
with pytest.raises((pyarrow.ArrowInvalid, pyarrow.ArrowTypeError)):
299+
dtype.coerce(data)

tests/strategies/test_strategies.py

+2
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@
6464
pandas_engine.ArrowUInt16,
6565
pandas_engine.ArrowUInt32,
6666
pandas_engine.ArrowUInt64,
67+
pandas_engine.ArrowList,
68+
pandas_engine.ArrowStruct,
6769
]
6870
)
6971

0 commit comments

Comments
 (0)