Skip to content

Commit 09257ce

Browse files
committed
BUG: ensure consistent result dtype of count_nonzero
The slowpath using apply_along_axis for size 1 axis did not ensure that the dtype is intp like all other paths. This caused inconsistent dtypes on windows where the default integer type is int32. Closes numpygh-9468
1 parent b19fe29 commit 09257ce

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

numpy/core/numeric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ def count_nonzero(a, axis=None):
445445
counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
446446

447447
if axis.size == 1:
448-
return counts
448+
return counts.astype(np.intp, copy=False)
449449
else:
450450
# for subsequent axis numbers, that number decreases
451451
# by one in this new 'counts' array if it was larger

numpy/core/tests/test_numeric.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,10 @@ def test_count_nonzero_axis_all_dtypes(self):
10201020
# either integer or tuple arguments for axis
10211021
msg = "Mismatch for dtype: %s"
10221022

1023+
def assert_equal_w_dt(a, b, err_msg):
1024+
assert_equal(a.dtype, b.dtype, err_msg=err_msg)
1025+
assert_equal(a, b, err_msg=err_msg)
1026+
10231027
for dt in np.typecodes['All']:
10241028
err_msg = msg % (np.dtype(dt).name,)
10251029

@@ -1039,13 +1043,13 @@ def test_count_nonzero_axis_all_dtypes(self):
10391043
m[1, 0] = '1970-01-12'
10401044
m = m.astype(dt)
10411045

1042-
expected = np.array([2, 0, 0])
1043-
assert_equal(np.count_nonzero(m, axis=0),
1044-
expected, err_msg=err_msg)
1046+
expected = np.array([2, 0, 0], dtype=np.intp)
1047+
assert_equal_w_dt(np.count_nonzero(m, axis=0),
1048+
expected, err_msg=err_msg)
10451049

1046-
expected = np.array([1, 1, 0])
1047-
assert_equal(np.count_nonzero(m, axis=1),
1048-
expected, err_msg=err_msg)
1050+
expected = np.array([1, 1, 0], dtype=np.intp)
1051+
assert_equal_w_dt(np.count_nonzero(m, axis=1),
1052+
expected, err_msg=err_msg)
10491053

10501054
expected = np.array(2)
10511055
assert_equal(np.count_nonzero(m, axis=(0, 1)),
@@ -1060,13 +1064,13 @@ def test_count_nonzero_axis_all_dtypes(self):
10601064
# setup is slightly different for this dtype
10611065
m = np.array([np.void(1)] * 6).reshape((2, 3))
10621066

1063-
expected = np.array([0, 0, 0])
1064-
assert_equal(np.count_nonzero(m, axis=0),
1065-
expected, err_msg=err_msg)
1067+
expected = np.array([0, 0, 0], dtype=np.intp)
1068+
assert_equal_w_dt(np.count_nonzero(m, axis=0),
1069+
expected, err_msg=err_msg)
10661070

1067-
expected = np.array([0, 0])
1068-
assert_equal(np.count_nonzero(m, axis=1),
1069-
expected, err_msg=err_msg)
1071+
expected = np.array([0, 0], dtype=np.intp)
1072+
assert_equal_w_dt(np.count_nonzero(m, axis=1),
1073+
expected, err_msg=err_msg)
10701074

10711075
expected = np.array(0)
10721076
assert_equal(np.count_nonzero(m, axis=(0, 1)),

0 commit comments

Comments
 (0)