Skip to content

Commit

Permalink
Fix error with scalar inputs in CAReduce and squeeze
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Apr 19, 2022
1 parent 00e0d80 commit 75b7233
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 2 deletions.
6 changes: 5 additions & 1 deletion aesara/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,11 @@ def make_node(self, input):
axis = list(range(inp_dims))

copy_op = any(a < 0 for a in axis)
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=inp_dims)
# scalar inputs are treated as 1D regarding axis in this `Op`
try:
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, inp_dims))
except np.AxisError:
raise np.AxisError(axis, ndim=inp_dims)

# We can't call self.__class__() as there is a class that
# inherits from CAReduce that doesn't have the same signature
Expand Down
6 changes: 5 additions & 1 deletion aesara/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,11 @@ def squeeze(x, axis=None):
elif not isinstance(axis, Collection):
axis = (axis,)

axis = np.core.numeric.normalize_axis_tuple(axis, ndim=x.ndim)
# scalar inputs are treated as 1D regarding axis in this `Op`
try:
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, x.ndim))
except np.AxisError:
raise np.AxisError(axis, ndim=x.ndim)

return x.dimshuffle([i for i in range(x.ndim) if i not in axis])

Expand Down
12 changes: 12 additions & 0 deletions tests/tensor/test_elemwise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import re
import tracemalloc
from copy import copy

Expand Down Expand Up @@ -638,6 +639,17 @@ def test_repeated_axis(self):
with pytest.raises(ValueError, match="repeated axis"):
self.op(aes.add, axis=(0, 0))(x)

def test_scalar_input(self):
x = scalar("x")

assert self.op(aes.add, axis=(-1,))(x).eval({x: 5}) == 5

with pytest.raises(
np.AxisError,
match=re.escape("axis (-2,) is out of bounds for array of dimension 0"),
):
self.op(aes.add, axis=(-2,))(x)


class TestBitOpReduceGrad:
def setup_method(self):
Expand Down
13 changes: 13 additions & 0 deletions tests/tensor/test_extra_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import numpy as np
import pytest

Expand Down Expand Up @@ -448,6 +450,17 @@ def test_invalid_axis(self):
):
squeeze(variable, axis=1)

def test_scalar_input(self):
x = at.scalar("x")

assert squeeze(x, axis=(0,)).eval({x: 5}) == 5

with pytest.raises(
np.AxisError,
match=re.escape("axis (1,) is out of bounds for array of dimension 0"),
):
squeeze(x, axis=1)


class TestCompress(utt.InferShapeTester):
axis_list = [None, -1, 0, 0, 0, 1]
Expand Down

0 comments on commit 75b7233

Please sign in to comment.