Skip to content

Commit d6d7278

Browse files
authored
Merge pull request numpy#9638 from juliantaylor/nonzero-dtype
BUG: ensure consistent result dtype of count_nonzero
2 parents 2afa142 + 09257ce commit d6d7278

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

numpy/core/numeric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def count_nonzero(a, axis=None):
444444
counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
445445

446446
if axis.size == 1:
447-
return counts
447+
return counts.astype(np.intp, copy=False)
448448
else:
449449
# for subsequent axis numbers, that number decreases
450450
# by one in this new 'counts' array if it was larger

numpy/core/tests/test_numeric.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,10 @@ def test_count_nonzero_axis_all_dtypes(self):
10261026
# either integer or tuple arguments for axis
10271027
msg = "Mismatch for dtype: %s"
10281028

1029+
def assert_equal_w_dt(a, b, err_msg):
1030+
assert_equal(a.dtype, b.dtype, err_msg=err_msg)
1031+
assert_equal(a, b, err_msg=err_msg)
1032+
10291033
for dt in np.typecodes['All']:
10301034
err_msg = msg % (np.dtype(dt).name,)
10311035

@@ -1045,13 +1049,13 @@ def test_count_nonzero_axis_all_dtypes(self):
10451049
m[1, 0] = '1970-01-12'
10461050
m = m.astype(dt)
10471051

1048-
expected = np.array([2, 0, 0])
1049-
assert_equal(np.count_nonzero(m, axis=0),
1050-
expected, err_msg=err_msg)
1052+
expected = np.array([2, 0, 0], dtype=np.intp)
1053+
assert_equal_w_dt(np.count_nonzero(m, axis=0),
1054+
expected, err_msg=err_msg)
10511055

1052-
expected = np.array([1, 1, 0])
1053-
assert_equal(np.count_nonzero(m, axis=1),
1054-
expected, err_msg=err_msg)
1056+
expected = np.array([1, 1, 0], dtype=np.intp)
1057+
assert_equal_w_dt(np.count_nonzero(m, axis=1),
1058+
expected, err_msg=err_msg)
10551059

10561060
expected = np.array(2)
10571061
assert_equal(np.count_nonzero(m, axis=(0, 1)),
@@ -1066,13 +1070,13 @@ def test_count_nonzero_axis_all_dtypes(self):
10661070
# setup is slightly different for this dtype
10671071
m = np.array([np.void(1)] * 6).reshape((2, 3))
10681072

1069-
expected = np.array([0, 0, 0])
1070-
assert_equal(np.count_nonzero(m, axis=0),
1071-
expected, err_msg=err_msg)
1073+
expected = np.array([0, 0, 0], dtype=np.intp)
1074+
assert_equal_w_dt(np.count_nonzero(m, axis=0),
1075+
expected, err_msg=err_msg)
10721076

1073-
expected = np.array([0, 0])
1074-
assert_equal(np.count_nonzero(m, axis=1),
1075-
expected, err_msg=err_msg)
1077+
expected = np.array([0, 0], dtype=np.intp)
1078+
assert_equal_w_dt(np.count_nonzero(m, axis=1),
1079+
expected, err_msg=err_msg)
10761080

10771081
expected = np.array(0)
10781082
assert_equal(np.count_nonzero(m, axis=(0, 1)),

0 commit comments

Comments
 (0)