Skip to content

Commit 9667234

Browse files
authored
feat: add overload. (#1823)
* feat: add overload. Signed-off-by: yassun7010 <[email protected]> * run pre-commit Signed-off-by: yassun7010 <[email protected]> * test: add test cases. Signed-off-by: yassun7010 <[email protected]> * fix: add pylint disable. Signed-off-by: yassun7010 <[email protected]> * fix: run black format. Signed-off-by: yassun7010 <[email protected]> --------- Signed-off-by: yassun7010 <[email protected]>
1 parent a0ac2a1 commit 9667234

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

pandera/dtypes.py

+15
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
Union,
1919
)
2020

21+
from typing_extensions import overload
22+
2123
try:
2224
# python 3.8+
2325
from typing import Literal # type: ignore[attr-defined]
@@ -89,6 +91,19 @@ def __hash__(self) -> int:
8991
_DataTypeClass = Type[_Dtype]
9092

9193

94+
@overload
95+
def immutable(
96+
pandera_dtype_cls: _DataTypeClass, # pylint: disable=W0613
97+
**dataclass_kwargs: Any, # pylint: disable=W0613
98+
) -> _DataTypeClass: ...
99+
100+
101+
@overload
102+
def immutable(
103+
**dataclass_kwargs: Any, # pylint: disable=W0613
104+
) -> Callable[[_DataTypeClass], _DataTypeClass]: ...
105+
106+
92107
def immutable(
93108
pandera_dtype_cls: Optional[_DataTypeClass] = None, **dataclass_kwargs: Any
94109
) -> Union[_DataTypeClass, Callable[[_DataTypeClass], _DataTypeClass]]:

tests/core/test_model.py

+12
Original file line numberDiff line numberDiff line change
@@ -1515,3 +1515,15 @@ def sqrt(cls, series):
15151515
assert Schema.validate(df).equals( # type: ignore [attr-defined]
15161516
pd.DataFrame({"a": [11.0], "abc": [1.0], "cba": [200.0]})
15171517
)
1518+
1519+
1520+
def test_pandera_dtype() -> None:
1521+
class Schema(pa.DataFrameModel):
1522+
a: Series[pa.Float]
1523+
b: Series[pa.Int]
1524+
c: Series[pa.String]
1525+
1526+
df = pd.DataFrame({"a": [1.0], "b": [1], "c": ["1"]})
1527+
assert Schema.validate(df).equals( # type: ignore [attr-defined]
1528+
pd.DataFrame({"a": [1.0], "b": [1], "c": ["1"]})
1529+
)

0 commit comments

Comments
 (0)