Skip to content

Commit 5d97ffa

Browse files
author
Ricardo Vieira
committed
Implement moveaxis
1 parent 0d69809 commit 5d97ffa

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

aesara/tensor/basic.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
from collections.abc import Sequence
1111
from functools import partial
1212
from numbers import Number
13-
from typing import Optional, Tuple, Union
13+
from typing import Optional
14+
from typing import Sequence as TypeSequence
15+
from typing import Tuple, Union
1416
from typing import cast as type_cast
1517

1618
import numpy as np
1719
from numpy.core.multiarray import normalize_axis_index
20+
from numpy.core.numeric import normalize_axis_tuple
1821

1922
import aesara
2023
import aesara.scalar.sharedvar
@@ -3635,6 +3638,51 @@ def swapaxes(y, axis1, axis2):
36353638
return y.dimshuffle(li)
36363639

36373640

3641+
def moveaxis(
3642+
a: Union[np.ndarray, TensorVariable],
3643+
source: Union[int, TypeSequence[int]],
3644+
destination: Union[int, TypeSequence[int]],
3645+
) -> TensorVariable:
3646+
"""Move axes of a TensorVariable to new positions.
3647+
3648+
Other axes remain in their original order.
3649+
3650+
Parameters
3651+
----------
3652+
a
3653+
The TensorVariable whose axes should be reordered.
3654+
source
3655+
Original positions of the axes to move. These must be unique.
3656+
destination
3657+
Destination positions for each of the original axes. These must also be
3658+
unique.
3659+
3660+
Returns
3661+
-------
3662+
result
3663+
TensorVariable with moved axes.
3664+
3665+
"""
3666+
3667+
a = as_tensor_variable(a)
3668+
3669+
source = normalize_axis_tuple(source, a.ndim, "source")
3670+
destination = normalize_axis_tuple(destination, a.ndim, "destination")
3671+
3672+
if len(source) != len(destination):
3673+
raise ValueError(
3674+
"`source` and `destination` arguments must have the same number of elements"
3675+
)
3676+
3677+
order = [n for n in range(a.ndim) if n not in source]
3678+
3679+
for dest, src in sorted(zip(destination, source)):
3680+
order.insert(dest, src)
3681+
3682+
result = a.dimshuffle(order)
3683+
return result
3684+
3685+
36383686
def choose(a, choices, mode="raise"):
36393687
"""
36403688
Construct an array from an index array and a set of arrays to choose from.
@@ -4014,6 +4062,7 @@ def take_along_axis(arr, indices, axis=0):
40144062
"atleast_3d",
40154063
"choose",
40164064
"swapaxes",
4065+
"moveaxis",
40174066
"stacklists",
40184067
"diag",
40194068
"diagonal",

tests/tensor/test_basic.py

+18
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
join,
6161
make_vector,
6262
mgrid,
63+
moveaxis,
6364
nonzero,
6465
nonzero_values,
6566
ogrid,
@@ -3984,6 +3985,23 @@ def test_numpy_compare(self):
39843985
assert np.allclose(n_s, t_s)
39853986

39863987

3988+
def test_moveaxis():
3989+
x = at.zeros((3, 4, 5))
3990+
tuple(moveaxis(x, 0, -1).shape.eval()) == (4, 5, 3)
3991+
tuple(moveaxis(x, -1, 0).shape.eval()) == (5, 3, 4)
3992+
tuple(moveaxis(x, [0, 1], [-1, -2]).shape.eval()) == (5, 4, 3)
3993+
tuple(moveaxis(x, [0, 1, 2], [-1, -2, -3]).shape.eval()) == (5, 4, 3)
3994+
3995+
3996+
def test_moveaxis_error():
3997+
x = at.zeros((3, 4, 5))
3998+
with pytest.raises(
3999+
ValueError,
4000+
match="`source` and `destination` arguments must have the same number of elements",
4001+
):
4002+
moveaxis(x, [0, 1], 0)
4003+
4004+
39874005
class TestChoose(utt.InferShapeTester):
39884006
op = staticmethod(choose)
39894007
op_class = Choose

0 commit comments

Comments
 (0)