From d3462bd38d73ccc77d6767bcfb006004853a7f06 Mon Sep 17 00:00:00 2001 From: Ghislain Piot Date: Thu, 6 Jun 2024 18:04:23 +0200 Subject: [PATCH] Correct some of the Scikit-learn stubs (#311) --- stubs/sklearn/metrics/_classification.pyi | 2 +- stubs/sklearn/model_selection/_split.pyi | 2 +- stubs/sklearn/utils/__init__.pyi | 6 ++++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/stubs/sklearn/metrics/_classification.pyi b/stubs/sklearn/metrics/_classification.pyi index 80d06354..94b95746 100644 --- a/stubs/sklearn/metrics/_classification.pyi +++ b/stubs/sklearn/metrics/_classification.pyi @@ -112,7 +112,7 @@ def precision_recall_fscore_support( labels: None | ArrayLike = None, pos_label: str | int = 1, average: None | Literal["binary", "micro", "macro", "samples", "weighted"] = None, - warn_for: set | tuple = ..., + warn_for: list | set | tuple = ..., sample_weight: None | ArrayLike = None, zero_division: Literal["warn", "warn"] | int = "warn", ) -> tuple[float | ndarray, float | ndarray, float | ndarray, None | ndarray]: ... diff --git a/stubs/sklearn/model_selection/_split.pyi b/stubs/sklearn/model_selection/_split.pyi index f0841a13..885d8fa8 100644 --- a/stubs/sklearn/model_selection/_split.pyi +++ b/stubs/sklearn/model_selection/_split.pyi @@ -204,7 +204,7 @@ class _CVIterableWrapper(BaseCrossValidator): def split(self, X: Any = None, y: Any = None, groups: Any = None): ... def check_cv( - cv: Iterable | int | BaseShuffleSplit | BaseCrossValidator = 5, + cv: Iterable | int | BaseShuffleSplit | BaseCrossValidator | None = 5, y: None | ArrayLike = None, *, classifier: bool = False, diff --git a/stubs/sklearn/utils/__init__.pyi b/stubs/sklearn/utils/__init__.pyi index 5d2d6ed9..a1068c56 100644 --- a/stubs/sklearn/utils/__init__.pyi +++ b/stubs/sklearn/utils/__init__.pyi @@ -83,8 +83,10 @@ def resample( n_samples: None | Int = None, random_state: RandomState | None | Int = None, stratify: None | MatrixLike | ArrayLike = None, -) -> list[ndarray]: ... -def shuffle(*arrays, random_state: RandomState | None | Int = None, n_samples: None | Int = None) -> list[SupportsIndex]: ... +) -> list[ndarray] | None: ... +def shuffle( + *arrays, random_state: RandomState | None | Int = None, n_samples: None | Int = None +) -> list[SupportsIndex] | None: ... def safe_sqr(X: MatrixLike | ArrayLike, *, copy: bool = True) -> ndarray: ... def gen_batches(n: Int, batch_size: Int, *, min_batch_size: Int = 0) -> Iterator[slice]: ... def gen_even_slices(n: Int, n_packs: Int, *, n_samples: None | Int = None) -> Iterator[slice]: ...