Skip to content

Commit 93f6557

Browse files
committed
Add tests for sklearn normalize overloads
1 parent 4a44c09 commit 93f6557

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

tests/requirements.txt

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
pyright
21
matplotlib
3-
pytest
42
mypy==0.950
3+
pyright
4+
pytest
5+
scikit-learn
6+
scipy
57
typing_extensions==4.2.0

tests/sklearn/preprocessing_tests.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# pyright: reportUnknownVariableType=false
2+
# pyright: reportMissingTypeStubs=false
3+
4+
from typing import Any, assert_type
5+
from numpy import ndarray
6+
from sklearn.preprocessing import normalize
7+
8+
from scipy.sparse._matrix import spmatrix
9+
from scipy.sparse._csr import csr_matrix
10+
11+
12+
# normalize with matrix
13+
matrix: spmatrix = spmatrix()
14+
result = normalize(matrix)
15+
assert_type(result, csr_matrix)
16+
17+
result = normalize(matrix, return_norm=False)
18+
assert_type(result, csr_matrix)
19+
20+
result = normalize(matrix, return_norm=True)
21+
assert_type(result, tuple[csr_matrix, ndarray[Any, Any]])
22+
23+
# normalize with array
24+
array_like = [1]
25+
result = normalize(array_like)
26+
assert_type(result, ndarray[Any, Any])
27+
28+
result = normalize(array_like, return_norm=False)
29+
assert_type(result, ndarray[Any, Any])
30+
31+
result = normalize(array_like, return_norm=True)
32+
assert_type(result, tuple[ndarray[Any, Any], ndarray[Any, Any]])

0 commit comments

Comments
 (0)