
Commit 32cb5f4

ricardoV94, Brandon T. Willard, purna135, Sayam Kumar, and kc611 committed
Implement Blockwise Op to vectorize existing Ops
Inspired by: aesara-devs/aesara#1215

Co-authored-by: Brandon T. Willard <[email protected]>
Co-authored-by: Purna Chandra Mansingh <[email protected]>
Co-authored-by: Sayam Kumar <[email protected]>
Co-authored-by: Kaustubh <[email protected]>
1 parent b9cbdff commit 32cb5f4

File tree

8 files changed: +872 -17 lines changed


pytensor/tensor/blockwise.py (+430, new file)

Large diffs are not rendered by default.
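Since the new module's diff is not rendered here, the snippet below is a minimal usage sketch (not part of this diff) of the added Blockwise Op, mirroring the tests introduced later in this commit. The final evaluation assumes the Blockwise.perform implemented in blockwise.py handles the batched computation.

# Minimal usage sketch of Blockwise (illustrative, not part of the diff).
import numpy as np

from pytensor import function
from pytensor.tensor import tensor3
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixPinv

# Wrap a core Op with a gufunc-like signature; leading dims beyond the
# core signature are treated as batch dimensions.
batched_pinv = Blockwise(MatrixPinv(hermitian=False), signature="(m,n)->(n,m)")

x = tensor3("x")          # shape (batch, m, n)
out = batched_pinv(x)     # shape (batch, n, m)
assert out.type.ndim == 3

# Assumes Blockwise.perform evaluates the core Op over the batch dimension.
fn = function([x], out)
res = fn(np.random.normal(size=(4, 3, 2)).astype(x.dtype))
assert res.shape == (4, 2, 3)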

pytensor/tensor/elemwise.py (+38 -17)
@@ -22,13 +22,15 @@
 from pytensor.tensor import _get_vector_length, as_tensor_variable
 from pytensor.tensor import elemwise_cgen as cgen
 from pytensor.tensor import get_vector_length
+from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
 from pytensor.tensor.type import (
     TensorType,
     continuous_dtypes,
     discrete_dtypes,
     float_dtypes,
     lvector,
 )
+from pytensor.tensor.utils import import_func_from_string
 from pytensor.tensor.var import TensorVariable
 from pytensor.utils import uniq

@@ -228,7 +230,7 @@ def __str__(self):
             return f"Transpose{{axes={self.shuffle}}}"
         return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"

-    def perform(self, node, inp, out, params):
+    def perform(self, node, inp, out, params=None):
         (res,) = inp
         (storage,) = out

@@ -662,22 +664,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
             impl = "c"

         if getattr(self, "nfunc_spec", None) and impl != "c":
-            self.nfunc = getattr(np, self.nfunc_spec[0], None)
-            if self.nfunc is None:
-                # Not inside NumPy. So probably another package like scipy.
-                symb = self.nfunc_spec[0].split(".")
-                for idx in range(1, len(self.nfunc_spec[0])):
-                    try:
-                        module = __import__(".".join(symb[:idx]))
-                    except ImportError:
-                        break
-                for sub in symb[1:]:
-                    try:
-                        module = getattr(module, sub)
-                    except AttributeError:
-                        module = None
-                        break
-                self.nfunc = module
+            self.nfunc = import_func_from_string(self.nfunc_spec[0])

         if (
             (len(node.inputs) + len(node.outputs)) <= 32
@@ -1759,3 +1746,37 @@ def _get_vector_length_Elemwise(op, var):
         return get_vector_length(var.owner.inputs[0])

     raise ValueError(f"Length of {var} cannot be determined")
+
+
+_vectorize_node.register(Elemwise, vectorize_not_needed)
+
+
+@_vectorize_node.register(DimShuffle)
+def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
+    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
+    if not batched_ndims:
+        return node.op.make_node(x)
+    input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
+    # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
+    # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
+    new_order = list(range(batched_ndims)) + [
+        "x" if (o == "x") else (o + batched_ndims) for o in op.new_order
+    ]
+    return DimShuffle(input_broadcastable, new_order).make_node(x)
+
+
+@_vectorize_node.register(CAReduce)
+def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply:
+    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
+    if not batched_ndims:
+        return node.op.make_node(x)
+    axes = op.axis
+    # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
+    # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
+    if axes is None:
+        axes = list(range(node.inputs[0].type.ndim))
+    else:
+        axes = list(axes)
+    new_axes = [axis + batched_ndims for axis in axes]
+    new_op = op.clone(axis=new_axes)
+    return new_op.make_node(x)
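
To make the effect of these dispatch registrations concrete, here is a small, hypothetical illustration (not part of the diff) of vectorize_node shifting a reduction's axes past the batch dimension. It assumes vectorize_node is called as in the rewrite module below, and that Sum dispatches through the CAReduce registration above.

# Hypothetical illustration of vectorize_careduce: the reduced axis is shifted
# past the new batch dimension (axis 0 of the matrix becomes axis 1 of the tensor3).
from pytensor.tensor import matrix, tensor3
from pytensor.tensor.blockwise import vectorize_node

x = matrix("x")
core_node = x.sum(axis=0).owner      # CAReduce (Sum) over axis 0 of a 2D input

batched_x = tensor3("batched_x")     # one extra leading batch dimension
new_node = vectorize_node(core_node, batched_x)

print(new_node.op.axis)              # expected: (1,)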

pytensor/tensor/random/op.py (+6)
@@ -16,6 +16,7 @@
     get_vector_length,
     infer_static_shape,
 )
+from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
 from pytensor.tensor.random.utils import normalize_size_param, params_broadcast_shapes
 from pytensor.tensor.shape import shape_tuple
@@ -428,3 +429,8 @@ class DefaultGeneratorMakerOp(AbstractRNGConstructor):


 default_rng = DefaultGeneratorMakerOp()
+
+
+# RandomVariables are vectorized on the parameters by default.
+# RNG, size and dtype can't be vectorized, but the Op will raise if the wrong input type is passed
+_vectorize_node.register(RandomVariable, vectorize_not_needed)

pytensor/tensor/rewriting/__init__.py (+1)
@@ -2,6 +2,7 @@
 import pytensor.tensor.rewriting.blas
 import pytensor.tensor.rewriting.blas_c
 import pytensor.tensor.rewriting.blas_scipy
+import pytensor.tensor.rewriting.blockwise
 import pytensor.tensor.rewriting.elemwise
 import pytensor.tensor.rewriting.extra_ops

pytensor/tensor/rewriting/blockwise.py (+39, new file)

@@ -0,0 +1,39 @@
+from pytensor.compile.mode import optdb
+from pytensor.graph import node_rewriter
+from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
+from pytensor.tensor.blockwise import Blockwise, vectorize_node
+from pytensor.tensor.rewriting.basic import register_useless
+
+
+@register_useless("fast_compile")
+@node_rewriter([Blockwise])
+def local_useless_blockwise(fgraph, node):
+    # If there is a dispatch implementation that does not require Blockwise, use that instead.
+    # This means a user created a Blockwise manually when there was no need.
+    op = node.op
+    inputs = node.inputs
+    dummy_core_node = op._create_dummy_core_node(node.inputs)
+    vect_node = vectorize_node(dummy_core_node, *inputs)
+    if not isinstance(vect_node.op, Blockwise):
+        return copy_stack_trace(node.outputs, vect_node.outputs)
+
+
+@node_rewriter([Blockwise])
+def local_useless_unbatched_blockwise(fgraph, node):
+    """Remove Blockwise that don't have any batched dims."""
+    op = node.op
+    inputs = node.inputs
+
+    if max(inp.type.ndim - len(sig) for inp, sig in zip(inputs, op.inputs_sig)) == 0:
+        return copy_stack_trace(node.outputs, op.core_op.make_node(*inputs).outputs)
+
+
+# We register this rewrite late, so that other rewrites need only target Blockwise Ops
+optdb.register(
+    "local_useless_unbatched_blockwise",
+    out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
+    "fast_run",
+    "fast_compile",
+    "blockwise",
+    position=49,
+)

pytensor/tensor/utils.py (+24)
@@ -107,3 +107,27 @@ def as_list(x):
         return list(x)
     except TypeError:
         return [x]
+
+
+def import_func_from_string(func_string: str):  # -> Optional[Callable]:
+    func = getattr(np, func_string, None)
+    if func is not None:
+        return func
+
+    # Not inside NumPy. So probably another package like scipy.
+    module = None
+    items = func_string.split(".")
+    for idx in range(1, len(items)):
+        try:
+            module = __import__(".".join(items[:idx]))
+        except ImportError:
+            break
+
+    if module:
+        for sub in items[1:]:
+            try:
+                module = getattr(module, sub)
+            except AttributeError:
+                module = None
+                break
+    return module
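
A brief usage note for the helper above (illustrative, not part of the diff; assumes SciPy is installed): names resolvable directly on the NumPy namespace are returned as-is, while dotted paths fall back to importing the module chain.

# Illustrative usage of import_func_from_string (assumes scipy is installed).
import numpy as np

from pytensor.tensor.utils import import_func_from_string

assert import_func_from_string("exp") is np.exp        # found directly on the numpy namespace
erf = import_func_from_string("scipy.special.erf")     # imported via the dotted module path
assert erf is not None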
tests/tensor/rewriting/test_blockwise.py (+36, new file)

@@ -0,0 +1,36 @@
+from pytensor import function
+from pytensor.scalar import log as scalar_log
+from pytensor.tensor import matrix, tensor3
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.nlinalg import MatrixPinv
+
+
+def test_useless_blockwise_of_elemwise():
+    x = matrix("x")
+    out = Blockwise(Elemwise(scalar_log), signature="()->()")(x)
+
+    assert isinstance(out.owner.op, Blockwise)
+    assert isinstance(out.owner.op.core_op, Elemwise)
+
+    fn = function([x], out, mode="FAST_COMPILE")
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Elemwise)
+
+
+def test_useless_unbatched_blockwise():
+    x = matrix("x")
+    blockwise_op = Blockwise(MatrixPinv(hermitian=False), signature="(m,n)->(n,m)")
+    out = blockwise_op(x)
+
+    assert isinstance(out.owner.op, Blockwise)
+    assert isinstance(out.owner.op.core_op, MatrixPinv)
+
+    fn = function([x], out, mode="FAST_COMPILE")
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, MatrixPinv)
+
+    # Test that it's not removed when there are batched dims
+    x = tensor3("x")
+    out = blockwise_op(x)
+    fn = function([x], out, mode="FAST_COMPILE")
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv)
