diff --git a/aesara/tensor/elemwise.py b/aesara/tensor/elemwise.py index a36c8ef3d3..8228c43a28 100644 --- a/aesara/tensor/elemwise.py +++ b/aesara/tensor/elemwise.py @@ -1304,7 +1304,11 @@ def make_node(self, input): axis = list(range(inp_dims)) copy_op = any(a < 0 for a in axis) - axis = np.core.numeric.normalize_axis_tuple(axis, ndim=inp_dims) + # scalar inputs are treated as 1D regarding axis in this `Op` + try: + axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, inp_dims)) + except np.AxisError: + raise np.AxisError(axis, ndim=inp_dims) # We can't call self.__class__() as there is a class that # inherits from CAReduce that doesn't have the same signature diff --git a/aesara/tensor/extra_ops.py b/aesara/tensor/extra_ops.py index 0dc310fb62..2cf92c9216 100644 --- a/aesara/tensor/extra_ops.py +++ b/aesara/tensor/extra_ops.py @@ -640,7 +640,11 @@ def squeeze(x, axis=None): elif not isinstance(axis, Collection): axis = (axis,) - axis = np.core.numeric.normalize_axis_tuple(axis, ndim=x.ndim) + # scalar inputs are treated as 1D regarding axis in this `Op` + try: + axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, x.ndim)) + except np.AxisError: + raise np.AxisError(axis, ndim=x.ndim) return x.dimshuffle([i for i in range(x.ndim) if i not in axis]) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index cdc7a4c624..56a7beafee 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -1,4 +1,5 @@ import math +import re import tracemalloc from copy import copy @@ -638,6 +639,17 @@ def test_repeated_axis(self): with pytest.raises(ValueError, match="repeated axis"): self.op(aes.add, axis=(0, 0))(x) + def test_scalar_input(self): + x = scalar("x") + + assert self.op(aes.add, axis=(-1,))(x).eval({x: 5}) == 5 + + with pytest.raises( + np.AxisError, + match=re.escape("axis (-2,) is out of bounds for array of dimension 0"), + ): + self.op(aes.add, axis=(-2,))(x) + class TestBitOpReduceGrad: def setup_method(self): diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index db5b53af38..108aba6ed3 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -1,3 +1,5 @@ +import re + import numpy as np import pytest @@ -448,6 +450,17 @@ def test_invalid_axis(self): ): squeeze(variable, axis=1) + def test_scalar_input(self): + x = at.scalar("x") + + assert squeeze(x, axis=(0,)).eval({x: 5}) == 5 + + with pytest.raises( + np.AxisError, + match=re.escape("axis (1,) is out of bounds for array of dimension 0"), + ): + squeeze(x, axis=1) + class TestCompress(utt.InferShapeTester): axis_list = [None, -1, 0, 0, 0, 1]