@@ -22,13 +22,15 @@
 from pytensor.tensor import _get_vector_length, as_tensor_variable
 from pytensor.tensor import elemwise_cgen as cgen
 from pytensor.tensor import get_vector_length
+from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
 from pytensor.tensor.type import (
     TensorType,
     continuous_dtypes,
     discrete_dtypes,
     float_dtypes,
     lvector,
 )
+from pytensor.tensor.utils import import_func_from_string
 from pytensor.tensor.var import TensorVariable
 from pytensor.utils import uniq
 
@@ -228,7 +230,7 @@ def __str__(self):
             return f"Transpose{{axes={self.shuffle}}}"
         return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
 
-    def perform(self, node, inp, out, params):
+    def perform(self, node, inp, out, params=None):
        (res,) = inp
        (storage,) = out
 
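A side note on the `perform` change: giving `params` a default means callers that omit the argument keep working. A minimal, self-contained sketch (plain Python, not the actual PyTensor call path) of the storage convention this signature follows:

    import numpy as np

    class Demo:
        # Mirrors Op.perform's convention: `inp` and `out` are lists, and the
        # result is written into the output storage cell in place.
        def perform(self, node, inp, out, params=None):
            (res,) = inp
            (storage,) = out
            storage[0] = np.asarray(res).T

    out = [[None]]
    Demo().perform(None, [np.arange(6).reshape(2, 3)], out)  # params omitted
    print(out[0][0].shape)  # -> (3, 2)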
@@ -662,22 +664,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
             impl = "c"
 
         if getattr(self, "nfunc_spec", None) and impl != "c":
-            self.nfunc = getattr(np, self.nfunc_spec[0], None)
-            if self.nfunc is None:
-                # Not inside NumPy. So probably another package like scipy.
-                symb = self.nfunc_spec[0].split(".")
-                for idx in range(1, len(self.nfunc_spec[0])):
-                    try:
-                        module = __import__(".".join(symb[:idx]))
-                    except ImportError:
-                        break
-                for sub in symb[1:]:
-                    try:
-                        module = getattr(module, sub)
-                    except AttributeError:
-                        module = None
-                        break
-                self.nfunc = module
+            self.nfunc = import_func_from_string(self.nfunc_spec[0])
 
         if (
             (len(node.inputs) + len(node.outputs)) <= 32
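For reference, the extracted helper resolves a `nfunc_spec` name either as a NumPy attribute (e.g. `"sum"`) or as a dotted path into another package such as SciPy. A rough standalone sketch of that lookup, assuming this is what `import_func_from_string` does (the function below is illustrative, not the real helper):

    import importlib

    import numpy as np


    def resolve_func(func_string):
        """Return the callable named by e.g. "sum" or "scipy.special.erf"."""
        func = getattr(np, func_string, None)
        if func is None:
            # Not a NumPy attribute: treat the string as a dotted module path.
            module_path, _, func_name = func_string.rpartition(".")
            func = getattr(importlib.import_module(module_path), func_name)
        return func


    print(resolve_func("sum"))                # <function sum ...>
    print(resolve_func("scipy.special.erf"))  # <ufunc 'erf'>, if scipy is installed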
@@ -1759,3 +1746,37 @@ def _get_vector_length_Elemwise(op, var):
         return get_vector_length(var.owner.inputs[0])
 
     raise ValueError(f"Length of {var} cannot be determined")
+
+
+_vectorize_node.register(Elemwise, vectorize_not_needed)
+
+
+@_vectorize_node.register(DimShuffle)
+def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
+    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
+    if not batched_ndims:
+        return node.op.make_node(x)
+    input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
+    # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
+    # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
+    new_order = list(range(batched_ndims)) + [
+        "x" if (o == "x") else (o + batched_ndims) for o in op.new_order
+    ]
+    return DimShuffle(input_broadcastable, new_order).make_node(x)
+
+
+@_vectorize_node.register(CAReduce)
+def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply:
+    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
+    if not batched_ndims:
+        return node.op.make_node(x)
+    axes = op.axis
+    # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
+    # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
+    if axes is None:
+        axes = list(range(node.inputs[0].type.ndim))
+    else:
+        axes = list(axes)
+    new_axes = [axis + batched_ndims for axis in axes]
+    new_op = op.clone(axis=new_axes)
+    return new_op.make_node(x)
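Taken together, both registrations implement the same lifting rule: leading batch dimensions stay in front, and every original dimension index (for `DimShuffle`) or reduction axis (for `CAReduce`) shifts right by `batched_ndims`. A NumPy-only sketch of that index arithmetic, checked against a direct reduction (helper names here are illustrative, not part of the patch):

    import numpy as np


    def lift_order(new_order, batched_ndims):
        # Batch axes keep their leading positions; original axes shift right.
        return list(range(batched_ndims)) + [
            "x" if o == "x" else o + batched_ndims for o in new_order
        ]


    def lift_axes(axes, core_ndim, batched_ndims):
        if axes is None:
            axes = range(core_ndim)
        return tuple(a + batched_ndims for a in axes)


    print(lift_order([1, "x", 0], batched_ndims=2))       # [0, 1, 3, 'x', 2]
    print(lift_axes(None, core_ndim=2, batched_ndims=2))  # (2, 3)
    print(lift_axes((0,), core_ndim=2, batched_ndims=2))  # (2,)

    # Reducing the shifted "core" axes of a (5, 7, 2, 3) batch of matrices
    # matches summing each matrix individually.
    x = np.random.default_rng(0).normal(size=(5, 7, 2, 3))
    assert np.allclose(x.sum(axis=(2, 3)),
                       np.array([[m.sum() for m in row] for row in x]))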