Skip to content

Commit 612d25c

Browse files
authored
add tests for polars decorators (#1615)
Signed-off-by: cosmicBboy <[email protected]>
1 parent b11cc4d commit 612d25c

File tree

3 files changed

+111
-8
lines changed

3 files changed

+111
-8
lines changed

pandera/decorators.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,17 @@
2727

2828
from pandera import errors
2929
from pandera.api.base.error_handler import ErrorHandler
30-
from pandera.api.pandas.array import SeriesSchema
31-
from pandera.api.pandas.container import DataFrameSchema
32-
from pandera.api.pandas.model import DataFrameModel
30+
from pandera.api.dataframe.components import ComponentSchema
31+
from pandera.api.dataframe.container import DataFrameSchema
32+
from pandera.api.dataframe.model import DataFrameModel
3333
from pandera.inspection_utils import (
3434
is_classmethod_from_meta,
3535
is_decorated_classmethod,
3636
)
3737
from pandera.typing import AnnotationInfo
3838
from pandera.validation_depth import validation_type
3939

40-
Schemas = Union[DataFrameSchema, SeriesSchema]
40+
Schemas = Union[DataFrameSchema, ComponentSchema]
4141
InputGetter = Union[str, int]
4242
OutputGetter = Union[str, int, Callable]
4343
F = TypeVar("F", bound=Callable)
@@ -84,7 +84,7 @@ def _get_fn_argnames(fn: Callable) -> List[str]:
8484
def _handle_schema_error(
8585
decorator_name,
8686
fn: Callable,
87-
schema: Union[DataFrameSchema, SeriesSchema],
87+
schema: Union[DataFrameSchema, ComponentSchema],
8888
data_obj: Any,
8989
schema_error: errors.SchemaError,
9090
) -> NoReturn:
@@ -110,7 +110,7 @@ def _handle_schema_error(
110110
def _parse_schema_error(
111111
decorator_name,
112112
fn: Callable,
113-
schema: Union[DataFrameSchema, SeriesSchema],
113+
schema: Union[DataFrameSchema, ComponentSchema],
114114
data_obj: Any,
115115
schema_error: errors.SchemaError,
116116
reason_code: errors.SchemaErrorReason,
@@ -355,7 +355,7 @@ def check_output(
355355
# pylint: disable=too-many-boolean-expressions
356356
if callable(obj_getter) and (
357357
schema.coerce
358-
or (schema.index is not None and schema.index.coerce)
358+
or (schema.index is not None and schema.index.coerce) # type: ignore[union-attr]
359359
or (
360360
isinstance(schema, DataFrameSchema)
361361
and any(col.coerce for col in schema.columns.values())
@@ -490,7 +490,7 @@ def _wrapper(
490490
out_schemas = out
491491
if isinstance(out, list):
492492
out_schemas = out
493-
elif isinstance(out, (DataFrameSchema, SeriesSchema)):
493+
elif isinstance(out, (DataFrameSchema, ComponentSchema)):
494494
out_schemas = [(None, out)] # type: ignore
495495
elif isinstance(out, tuple):
496496
out_schemas = [out]

pandera/typing/polars.py

+7
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ class LazyFrame(DataFrameBase, pl.LazyFrame, Generic[T]):
3535
*new in 0.19.0*
3636
"""
3737

38+
class DataFrame(DataFrameBase, pl.DataFrame, Generic[T]):
39+
"""
40+
Pandera generic for pl.LazyFrame, only used for type annotation.
41+
42+
*new in 0.19.0*
43+
"""
44+
3845
# pylint: disable=too-few-public-methods
3946
class Series(SeriesBase, pl.Series, Generic[T]):
4047
"""
+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""Unit tests for using schemas with polars and function decorators."""
2+
3+
import polars as pl
4+
import pytest
5+
6+
import pandera.polars as pa
7+
import pandera.typing.polars as pa_typing
8+
9+
10+
@pytest.fixture
11+
def data() -> pl.DataFrame:
12+
return pl.DataFrame({"a": [1, 2, 3]})
13+
14+
15+
@pytest.fixture
16+
def invalid_data(data) -> pl.DataFrame:
17+
return data.rename({"a": "b"})
18+
19+
20+
def test_polars_dataframe_check_io(data, invalid_data):
21+
# pylint: disable=unused-argument
22+
23+
schema = pa.DataFrameSchema({"a": pa.Column(int)})
24+
25+
@pa.check_input(schema)
26+
def fn_check_input(x):
27+
...
28+
29+
@pa.check_output(schema)
30+
def fn_check_output(x):
31+
return x
32+
33+
@pa.check_io(x=schema, out=schema)
34+
def fn_check_io(x):
35+
return x
36+
37+
@pa.check_io(x=schema, out=schema)
38+
def fn_check_io_invalid(x):
39+
return x.rename({"a": "b"})
40+
41+
# valid data should pass
42+
fn_check_input(data)
43+
fn_check_output(data)
44+
fn_check_io(data)
45+
46+
# invalid data or invalid function should not pass
47+
with pytest.raises(pa.errors.SchemaError):
48+
fn_check_input(invalid_data)
49+
50+
with pytest.raises(pa.errors.SchemaError):
51+
fn_check_output(invalid_data)
52+
53+
with pytest.raises(pa.errors.SchemaError):
54+
fn_check_io_invalid(data)
55+
56+
57+
def test_polars_dataframe_check_types(data, invalid_data):
58+
# pylint: disable=unused-argument
59+
60+
class Model(pa.DataFrameModel):
61+
a: int
62+
63+
@pa.check_types
64+
def fn_check_input(x: pa_typing.DataFrame[Model]):
65+
...
66+
67+
@pa.check_types
68+
def fn_check_output(x) -> pa_typing.DataFrame[Model]:
69+
return x
70+
71+
@pa.check_types
72+
def fn_check_io(
73+
x: pa_typing.DataFrame[Model],
74+
) -> pa_typing.DataFrame[Model]:
75+
return x
76+
77+
@pa.check_types
78+
def fn_check_io_invalid(
79+
x: pa_typing.DataFrame[Model],
80+
) -> pa_typing.DataFrame[Model]:
81+
return x.rename({"a": "b"})
82+
83+
# valid data should pass
84+
fn_check_input(data)
85+
fn_check_output(data)
86+
fn_check_io(data)
87+
88+
# invalid data or invalid function should not pass
89+
with pytest.raises(pa.errors.SchemaError):
90+
fn_check_input(invalid_data)
91+
92+
with pytest.raises(pa.errors.SchemaError):
93+
fn_check_output(invalid_data)
94+
95+
with pytest.raises(pa.errors.SchemaError):
96+
fn_check_io_invalid(data)

0 commit comments

Comments
 (0)