Implement Blockwise Op to vectorize existing Ops
Inspired by: aesara-devs/aesara#1215

Co-authored-by: Brandon T. Willard <[email protected]>
Co-authored-by: Purna Chandra Mansingh <[email protected]>
Co-authored-by: Sayam Kumar <[email protected]>
Co-authored-by: Kaustubh <[email protected]>
5 people committed May 17, 2023
1 parent 882b418 commit 911fe36
Showing 3 changed files with 357 additions and 95 deletions.
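
Editor's note: the sketch below is not part of the commit; it only illustrates, using the constructors and signatures visible in the diff that follows, how a core Op such as MatrixInverse can be wrapped in Blockwise so that leading ("batch") dimensions are looped over, in the spirit of NumPy gufuncs. Import paths and the tensor() helper are assumed from the changed files.

# Hypothetical usage sketch (not from the commit) of the new Blockwise Op.
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse

x = tensor(dtype="float64", shape=(5, 3, 3))  # a batch of five 3x3 matrices
batched_inv = Blockwise(MatrixInverse(), signature="(m,m)->(m,m)")
y = batched_inv(x)  # the core Op is applied across the leading batch dimension
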
193 changes: 98 additions & 95 deletions pytensor/tensor/blockwise.py
@@ -1,22 +1,20 @@
import re
from collections.abc import Sequence
from functools import singledispatch
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import numpy as np

from pytensor import config
from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply, Constant, Op
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.op import Op
from pytensor.graph.null_type import NullType
from pytensor.tensor import TensorVariable, as_tensor_variable, shape_padleft, tensor
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes
from pytensor.tensor.var import TensorVariable, as_tensor_variable
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor


# Some Ops are implemented in a way that they already batch natively
# TODO: Make them subclass from Blockwise and get rid of duplicated code
natively_batched_ops = (Elemwise, RandomVariable)
from pytensor.tensor.elemwise import Elemwise, DimShuffle

# TODO: Implement vectorize helper to batch whole graphs (similar to what Blockwise does for the grad)

@@ -29,6 +27,65 @@
_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST)


def safe_signature(
core_inputs: Sequence[TensorVariable],
core_outputs: Sequence[TensorVariable],
) -> str:
def operand_sig(operand: TensorVariable, prefix: str) -> str:
operands = ",".join(f"{prefix}{i}" for i in range(operand.type.ndim))
return f"({operands})"

inputs_sig = ",".join(
operand_sig(i, prefix=f"i{n}") for n, i in enumerate(core_inputs)
)
outputs_sig = ",".join(
operand_sig(o, prefix=f"o{n}") for n, o in enumerate(core_outputs)
)
return f"{inputs_sig}->{outputs_sig}"


@singledispatch
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
if hasattr(op, "gufunc_signature"):
signature = op.gufunc_signature
else:
# TODO: This is pretty bad for shape inference and merge optimization!
# Should get better as we add signatures to our Ops
signature = safe_signature(node.inputs, node.outputs)
return Blockwise(op, signature=signature).make_node(*batched_inputs)


def vectorize_node(node: Apply, *batched_inputs) -> Apply:
"""Returns vectorized version of node with new batched inputs."""

# Special cases for most common `Op`s that don't really need to be "vectorized"
# TODO: Other simple cases include Reshape, Alloc, ?

op = node.op

if isinstance(op, Elemwise):
return op.make_node(*batched_inputs)

if isinstance(op, Blockwise):
return op.make_node(*batched_inputs)

if isinstance(op, DimShuffle):
[x] = batched_inputs
batched_ndims = x.type.ndim - node.inputs[0].type.ndim
if not batched_ndims:
return node
input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
new_order = list(range(batched_ndims)) + ["x" if o == "x" else o + batched_ndims for o in op.new_order]
return DimShuffle(input_broadcastable, new_order).make_node(x)

from pytensor.tensor.random.op import RandomVariable
if isinstance(op, RandomVariable):
return op.make_node(*batched_inputs)

# Fallback to dispatch implementation so users can override behavior
return _vectorize_node(op, node, *batched_inputs)


def _parse_gufunc_signature(signature):
"""
Parse string signatures for a generalized universal function.
@@ -87,13 +144,18 @@ def __init__(
e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication
"""
# Some Ops are implemented in a way that they already batch natively
# TODO: Consider refactoring them into a shared class
from pytensor.tensor.random.op import RandomVariable
natively_batched_ops = (Elemwise, RandomVariable)

if isinstance(core_op, type(self)):
raise TypeError("core_op cannot be a Blockwise")
if isinstance(core_op, natively_batched_ops):
raise TypeError(f"{core_op} already works as a Blockwise")

if signature is None:
signature = getattr(core_op, "signature", None)
signature = getattr(core_op, "gufunc_signature", None)
if signature is None:
raise ValueError(
f"Signature not provided nor found in core_op {core_op}"
@@ -102,13 +164,13 @@
self.core_op = core_op
self.signature = signature
self.name = name
self._signature = _parse_gufunc_signature(signature)
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self._gufunc = None
super().__init__(**kwargs)

def _create_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
core_input_types = []
for i, (inp, sig) in enumerate(zip(inputs, self._signature[0])):
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
if inp.type.ndim < len(sig):
raise ValueError(
f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
@@ -122,11 +184,11 @@ def _create_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:

core_node = self.core_op.make_node(*core_input_types)

if len(core_node.outputs) != len(self._signature[1]):
if len(core_node.outputs) != len(self.outputs_sig):
raise ValueError(
f"Insufficient number of outputs for signature {self.signature}: {len(core_node.outputs)}"
)
for i, (core_out, sig) in enumerate(zip(core_node.outputs, self._signature[1])):
for i, (core_out, sig) in enumerate(zip(core_node.outputs, self.outputs_sig)):
if core_out.type.ndim != len(sig):
raise ValueError(
f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}"
@@ -137,14 +199,19 @@ def _create_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
def make_node(self, *inputs):
inputs = [as_tensor_variable(i) for i in inputs]

core_node = self._create_core_node(inputs)
core_node = self._create_dummy_core_node(inputs)

batch_ndims = max(
inp.type.ndim - len(sig) for inp, sig in zip(inputs, self._signature[0])
inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig)
)

# Don't pollute the graph with a useless Blockwise
if not batch_ndims:
return self.core_op.make_node(*inputs)

batched_inputs = []
batch_shapes = []
for i, (inp, sig) in enumerate(zip(inputs, self._signature[0])):
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
# Append missing dims to the left
missing_batch_ndims = batch_ndims - (inp.type.ndim - len(sig))
if missing_batch_ndims:
@@ -194,7 +261,7 @@ def get_most_specialized_batch_shape(
return Apply(self, batched_inputs, batched_outputs)

def _batch_ndim_from_outputs(self, outputs: Sequence[TensorVariable]) -> int:
return cast(int, outputs[0].type.ndim - len(self._signature[1][0]))
return cast(int, outputs[0].type.ndim - len(self.outputs_sig[0]))

def infer_shape(
self, fgraph, node, input_shapes
@@ -205,13 +272,13 @@ def infer_shape(
batch_ndims = self._batch_ndim_from_outputs(node.outputs)
core_dims: Dict[str, Any] = {}
batch_shapes = []
for input_shape, sig in zip(input_shapes, self._signature[0]):
for input_shape, sig in zip(input_shapes, self.inputs_sig):
batch_shapes.append(input_shape[:batch_ndims])
core_shape = input_shapes[batch_ndims:]
core_shape = input_shape[batch_ndims:]

for core_dim, dim_name in zip(core_shape, sig):
prev_core_dim = core_dims.get(core_dim)
if not prev_core_dim:
if prev_core_dim is None:
core_dims[dim_name] = core_dim
# Prefer constants
elif not isinstance(prev_core_dim, Constant):
@@ -220,7 +287,7 @@ def infer_shape(
batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)

out_shapes = []
for output, sig in zip(node.outputs, self._signature[1]):
for output, sig in zip(node.outputs, self.outputs_sig):
core_out_shape = []
for i, dim_name in enumerate(sig):
# The output dim is the same as another input dim
@@ -246,14 +313,14 @@ def as_core(t, core_t):
# Inputs could be NullType or DisconnectedType
if isinstance(t.type, (NullType, DisconnectedType)):
return t
return core_t
return core_t.type()

with config.change_flags(compute_test_value="off"):
safe_inputs = [
tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
for inp, sig in zip(inputs, self._signature[0])
for inp, sig in zip(inputs, self.inputs_sig)
]
core_node = self._create_core_node(safe_inputs)
core_node = self._create_dummy_core_node(safe_inputs)

core_inputs = [
as_core(inp, core_inp)
@@ -267,21 +334,6 @@ def as_core(t, core_t):

core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)

def safe_signature(
core_inputs: Sequence[TensorVariable],
core_outputs: Sequence[TensorVariable],
) -> str:
def operand_sig(operand: TensorVariable, prefix: str) -> str:
operands = ",".join(f"{prefix}{i}" for i in range(operand.type.ndim))
return f"({operands})"

inputs_sig = ",".join(
operand_sig(i, prefix=f"i{n}") for n, i in enumerate(core_inputs)
)
outputs_sig = ",".join(
operand_sig(o, prefix=f"o{n}") for n, o in enumerate(core_outputs)
)
return f"{inputs_sig}->{outputs_sig}"

batch_ndims = self._batch_ndim_from_outputs(outputs)

@@ -305,23 +357,8 @@ def transform(var):
return var

batched_inputs = [transform(inp) for inp in node.inputs]

if not batch_ndims or isinstance(node.op, natively_batched_ops):
batched_var = node.op.make_node(*batched_inputs)
elif node.op == self.core_op:
batched_var = self.make_node(*batched_inputs)
else:
if hasattr(node.op, "signature"):
grad_signature = node.op.signature
else:
# TODO: This is pretty bad for shape inference and merge optimization!
# Should get better as we add signature to our Ops
grad_signature = safe_signature(node.inputs, node.outputs)
batched_var = Blockwise(node.op, signature=grad_signature).make_node(
*batched_inputs
)

batched_var = batched_var.outputs[var.owner.outputs.index(var)]
batched_node = vectorize_node(node, *batched_inputs)
batched_var = batched_node.outputs[var.owner.outputs.index(var)]

return batched_var

@@ -335,40 +372,6 @@ def transform(var):

return ret

def R_op(self, inputs, eval_points):
from pytensor.tensor import ones_like

outs = self(*inputs, return_list=True)
rval = [None for _ in outs]
for idx, out in enumerate(outs):
# make such that _bgrads computes only the gradients of the
# current output on the inputs ( and not all outputs)
ograds = [x.zeros_like() for x in outs]
ograds[idx] = ones_like(out)

bgrads = self._bgrad(inputs, outs, ograds)
rop_out = None

for jdx, (inp, eval_point) in enumerate(zip(inputs, eval_points)):
# if None, then we can just ignore this branch ..
# what we do is to assume that for any non-differentiable
# branch, the gradient is actually 0, which I think is not
# the right thing to do .. have to talk to Ian and James
# about it
if bgrads[jdx] is None or isinstance(
bgrads[jdx].type, DisconnectedType
):
pass
elif eval_point is not None:
if rop_out is None:
rop_out = bgrads[jdx] * eval_point
else:
rop_out = rop_out + bgrads[jdx] * eval_point

rval[idx] = rop_out

return rval

def L_op(self, inputs, outs, ograds):
from pytensor.tensor.math import sum as pt_sum

@@ -399,7 +402,7 @@ def L_op(self, inputs, outs, ograds):
# Sum out the broadcasted dimensions
batch_ndims = self._batch_ndim_from_outputs(outs)
batch_shape = outs[0].type.shape[:batch_ndims]
for i, (inp, sig) in enumerate(zip(inputs, self._signature[0])):
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
if isinstance(rval[i].type, (NullType, DisconnectedType)):
continue

@@ -418,8 +421,8 @@ def L_op(self, inputs, outs, ograds):
def _create_gufunc(self, node):
# TODO: Use `impl` numpy versions just like Elemwise and ScalarOps do

n_outs = len(self._signature[0])
core_node = self._create_core_node(node.inputs)
n_outs = len(self.outputs_sig)
core_node = self._create_dummy_core_node(node.inputs)

def core_func(*inner_inputs):
inner_outputs = [[None] for _ in range(n_outs)]
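
Editor's note: to make the fallback signature path above concrete, here is a hedged sketch (not part of the commit) of what safe_signature produces for a core Op that declares no gufunc_signature. The helper names come from the blockwise.py code shown above; the example variables are hypothetical.

# Hypothetical illustration of safe_signature: every dimension becomes its own
# anonymous core dimension, which is valid but weak for shape inference
# (hence the TODO above about adding gufunc_signature to more Ops).
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import safe_signature

A = tensor(dtype="float64", shape=(3, 4))   # stands in for a rank-2 core input
b = tensor(dtype="float64", shape=(4,))     # rank-1 core input
out = tensor(dtype="float64", shape=(3,))   # rank-1 core output

sig = safe_signature([A, b], [out])
# sig == "(i00,i01),(i10)->(o00)"
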
8 changes: 8 additions & 0 deletions pytensor/tensor/nlinalg.py
@@ -15,6 +15,7 @@

class MatrixPinv(Op):
__props__ = ("hermitian",)
gufunc_signature = "(m,n)->(n,m)"

def __init__(self, hermitian):
self.hermitian = hermitian
@@ -81,6 +82,9 @@ def pinv(x, hermitian=False):
class Inv(Op):
"""Computes the inverse of one or more matrices."""

# TODO: This Op is already natively vectorized; dispatch on `vectorize_node` to avoid a useless Blockwise
gufunc_signature = "(m,m)->(m,m)"

def make_node(self, x):
x = as_tensor_variable(x)
return Apply(self, [x], [x.type()])
@@ -112,6 +116,7 @@ class MatrixInverse(Op):
"""

__props__ = ()
gufunc_signature = "(m,m)->(m,m)"

def __init__(self):
pass
@@ -200,6 +205,7 @@ class Det(Op):
"""

__props__ = ()
gufunc_signature = "(m,m)->()"

def make_node(self, x):
x = as_tensor_variable(x)
@@ -237,6 +243,7 @@ class SLogDet(Op):
"""

__props__ = ()
gufunc_signature = "(m,m)->(),()"

def make_node(self, x):
x = as_tensor_variable(x)
@@ -272,6 +279,7 @@ class Eig(Op):

_numop = staticmethod(np.linalg.eig)
__props__: Tuple[str, ...] = ()
gufunc_signature = "(m,m)->(m),(m,m)"

def make_node(self, x):
x = as_tensor_variable(x)
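
Editor's note: because these linear-algebra Ops now declare a gufunc_signature, Blockwise.__init__ (shown earlier) picks the signature up automatically, so no explicit signature argument is needed. A minimal hedged sketch, with import paths assumed from the changed files:

# Hypothetical sketch: Det's new gufunc_signature "(m,m)->()" is found by
# Blockwise automatically, yielding one determinant per matrix in the batch.
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import Det

x = tensor(dtype="float64", shape=(10, 4, 4))  # ten 4x4 matrices
dets = Blockwise(Det())(x)                     # symbolic vector of 10 determinants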