Skip to content

Commit 63140c9

Browse files
authored
bugfix: nullable check float dtype handles nan and null (#1627)
Signed-off-by: cosmicBboy <[email protected]>
1 parent 0faae07 commit 63140c9

File tree

2 files changed

+5
-8
lines changed

2 files changed

+5
-8
lines changed

pandera/backends/polars/components.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,9 @@ def check_nullable(
207207
)
208208
]
209209

210+
expr = pl.col(schema.selector).is_not_null()
210211
if is_float_dtype(check_obj, schema.selector):
211-
expr = pl.col(schema.selector).is_not_nan()
212-
else:
213-
expr = pl.col(schema.selector).is_not_null()
212+
expr = expr & pl.col(schema.selector).is_not_nan()
214213

215214
isna = check_obj.select(expr)
216215
passed = isna.select([pl.col("*").all()]).collect()

tests/polars/test_polars_components.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_coerce_dtype(data, from_dtype, to_dtype, exception_cls):
129129
NULLABLE_DTYPES_AND_DATA = [
130130
[pl.Int64, [1, 2, 3, None]],
131131
[pl.Utf8, ["foo", "bar", "baz", None]],
132-
[pl.Float64, [1.0, 2.0, 3.0, float("nan")]],
132+
[pl.Float64, [1.0, 2.0, 3.0, float("nan"), None]],
133133
[pl.Boolean, [True, False, True, None]],
134134
]
135135

@@ -138,7 +138,7 @@ def test_coerce_dtype(data, from_dtype, to_dtype, exception_cls):
138138
@pytest.mark.parametrize("nullable", [True, False])
139139
def test_check_nullable(dtype, data, nullable):
140140
data = pl.LazyFrame({"column": pl.Series(data, dtype=dtype)})
141-
column_schema = pa.Column(pl.Int64, nullable=nullable, name="column")
141+
column_schema = pa.Column(dtype, nullable=nullable, name="column")
142142
backend = ColumnBackend()
143143
check_results: List[CoreCheckResult] = backend.check_nullable(
144144
data, column_schema
@@ -153,9 +153,7 @@ def test_check_nullable_regex(dtype, data, nullable):
153153
data = pl.LazyFrame(
154154
{f"column_{i}": pl.Series(data, dtype=dtype) for i in range(3)}
155155
)
156-
column_schema = pa.Column(
157-
pl.Int64, nullable=nullable, name=r"^column_\d+$"
158-
)
156+
column_schema = pa.Column(dtype, nullable=nullable, name=r"^column_\d+$")
159157
backend = ColumnBackend()
160158
check_results = backend.check_nullable(data, column_schema)
161159
for result in check_results:

0 commit comments

Comments
 (0)