|
10 | 10 | from collections.abc import Sequence
|
11 | 11 | from functools import partial
|
12 | 12 | from numbers import Number
|
13 |
| -from typing import Optional, Tuple, Union |
| 13 | +from typing import Optional |
| 14 | +from typing import Sequence as TypeSequence |
| 15 | +from typing import Tuple, Union |
14 | 16 | from typing import cast as type_cast
|
15 | 17 |
|
16 | 18 | import numpy as np
|
17 | 19 | from numpy.core.multiarray import normalize_axis_index
|
| 20 | +from numpy.core.numeric import normalize_axis_tuple |
18 | 21 |
|
19 | 22 | import aesara
|
20 | 23 | import aesara.scalar.sharedvar
|
@@ -3635,6 +3638,51 @@ def swapaxes(y, axis1, axis2):
|
3635 | 3638 | return y.dimshuffle(li)
|
3636 | 3639 |
|
3637 | 3640 |
|
| 3641 | +def moveaxis( |
| 3642 | + a: Union[np.ndarray, TensorVariable], |
| 3643 | + source: Union[int, TypeSequence[int]], |
| 3644 | + destination: Union[int, TypeSequence[int]], |
| 3645 | +) -> TensorVariable: |
| 3646 | + """Move axes of a TensorVariable to new positions. |
| 3647 | +
|
| 3648 | + Other axes remain in their original order. |
| 3649 | +
|
| 3650 | + Parameters |
| 3651 | + ---------- |
| 3652 | + a |
| 3653 | + The TensorVariable whose axes should be reordered. |
| 3654 | + source |
| 3655 | + Original positions of the axes to move. These must be unique. |
| 3656 | + destination |
| 3657 | + Destination positions for each of the original axes. These must also be |
| 3658 | + unique. |
| 3659 | +
|
| 3660 | + Returns |
| 3661 | + ------- |
| 3662 | + result |
| 3663 | + TensorVariable with moved axes. |
| 3664 | +
|
| 3665 | + """ |
| 3666 | + |
| 3667 | + a = as_tensor_variable(a) |
| 3668 | + |
| 3669 | + source = normalize_axis_tuple(source, a.ndim, "source") |
| 3670 | + destination = normalize_axis_tuple(destination, a.ndim, "destination") |
| 3671 | + |
| 3672 | + if len(source) != len(destination): |
| 3673 | + raise ValueError( |
| 3674 | + "`source` and `destination` arguments must have the same number of elements" |
| 3675 | + ) |
| 3676 | + |
| 3677 | + order = [n for n in range(a.ndim) if n not in source] |
| 3678 | + |
| 3679 | + for dest, src in sorted(zip(destination, source)): |
| 3680 | + order.insert(dest, src) |
| 3681 | + |
| 3682 | + result = a.dimshuffle(order) |
| 3683 | + return result |
| 3684 | + |
| 3685 | + |
3638 | 3686 | def choose(a, choices, mode="raise"):
|
3639 | 3687 | """
|
3640 | 3688 | Construct an array from an index array and a set of arrays to choose from.
|
@@ -4014,6 +4062,7 @@ def take_along_axis(arr, indices, axis=0):
|
4014 | 4062 | "atleast_3d",
|
4015 | 4063 | "choose",
|
4016 | 4064 | "swapaxes",
|
| 4065 | + "moveaxis", |
4017 | 4066 | "stacklists",
|
4018 | 4067 | "diag",
|
4019 | 4068 | "diagonal",
|
|
0 commit comments