Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuse consecutive Elemwise subgraphs with multiple clients #1242

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
15 changes: 15 additions & 0 deletions aesara/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4165,6 +4165,21 @@ def init_fgraph(self):
"The fgraph to Composite must be exclusively"
" composed of ScalarOp instances."
)

# Clone identical outputs that have been merged
if len(set(fgraph.outputs)) != len(self.outputs):
old_outputs = fgraph.outputs
new_outputs = []
for output in old_outputs:
if output not in new_outputs:
new_outputs.append(output)
else:
node = output.owner
output_idx = node.outputs.index(output)
new_output = node.clone().outputs[output_idx]
new_outputs.append(new_output)
fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False)

self.fgraph = fgraph

def __init__(self, inputs, outputs):
Expand Down
754 changes: 435 additions & 319 deletions aesara/tensor/rewriting/elemwise.py

Large diffs are not rendered by default.

61 changes: 0 additions & 61 deletions aesara/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@
register_uncanonicalize,
register_useless,
)
from aesara.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import (
Expand Down Expand Up @@ -2843,66 +2842,6 @@ def check_input(inputs):
return [ret]


def local_add_mul_fusion(fgraph, node):
"""Fuse consecutive add or mul in one such node with more inputs.

It is better to fuse add/mul that way then in a Composite node as
this make the inner graph of the Composite smaller. This allow to
put more computation in a Composite before hitting the max
recursion limit when pickling Composite.

"""
if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, (aes.Add, aes.Mul)
):
return False

s_op = node.op.scalar_op.__class__
new_inp = []
fused = False
nb_inputs = len(node.inputs)
max_inputs = float("inf")
if hasattr(node.op, "max_inputs"):
max_inputs = node.op.max_inputs(node)
for inp in node.inputs:
if (
inp.owner
and isinstance(inp.owner.op, Elemwise)
and isinstance(inp.owner.op.scalar_op, s_op)
and
# Do not duplicate the operation.
len(fgraph.clients[inp]) == 1
and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
):
new_inp.extend(inp.owner.inputs)
fused = True
else:
new_inp.append(inp)

# We can not compare the number of inputs as Mul and Add could have
# 0 or 1 inputs in some corner cases.
if fused:
output = node.op(*new_inp)
copy_stack_trace(node.outputs[0], output)

# Do the recursion here to help lower the number of
# FusionOptimizer iteration.
if output.owner:
output2 = local_add_mul_fusion(fgraph, output.owner)
if output2:
return output2
return [output]


fuse_seqopt.register(
"local_add_mul_fusion",
FusionOptimizer(local_add_mul_fusion),
"fast_run",
"fusion",
position=0,
)


def _skip_mul_1(r):
if r.owner and r.owner.op == mul:
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
Expand Down
9 changes: 7 additions & 2 deletions tests/compile/function/test_pfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

import aesara.tensor as at
from aesara.compile import UnusedInputError
from aesara.compile import UnusedInputError, get_mode
from aesara.compile.function import function, pfunc
from aesara.compile.function.pfunc import rebuild_collect_shared
from aesara.compile.io import In
Expand Down Expand Up @@ -184,7 +184,12 @@ def test_shared_mutable(self):
bval = np.arange(5)
b.set_value(bval, borrow=True)
bval = data_of(b)
f = pfunc([], [b_out], updates=[(b, (b_out + 3))], mode="FAST_RUN")
f = pfunc(
[],
[b_out],
updates=[(b, (b_out + 3))],
mode=get_mode("FAST_RUN").excluding("fusion"),
)
assert (f() == (np.arange(5) * 2)).all()
# because of the update
assert (b.get_value(borrow=True) == ((np.arange(5) * 2) + 3)).all()
Expand Down
4 changes: 4 additions & 0 deletions tests/link/numba/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ def f_pow2(x_tm2, x_tm1):
state_val = np.array([1.0, 2.0])

numba_mode = get_mode("NUMBA").including("scan_save_mem")
# multi-output Elemwise not supported in NUMBA
numba_mode = numba_mode.excluding("fusion")
py_mode = Mode("py").including("scan_save_mem")

out_fg = FunctionGraph([init_x, n_steps], [output])
Expand Down Expand Up @@ -406,6 +408,8 @@ def inner_fct(seq, state_old, state_current):
g_outs = grad(out.sum(), [seq, init_x])

numba_mode = get_mode("NUMBA").including("scan_save_mem")
# multi-output Elemwise not supported in NUMBA
numba_mode = numba_mode.excluding("fusion")
py_mode = Mode("py").including("scan_save_mem")

out_fg = FunctionGraph([seq, init_x], g_outs)
Expand Down
11 changes: 11 additions & 0 deletions tests/scalar/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,17 @@ def test_many_outputs(self):
fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]

def test_identical_outputs(self):
x, y, z = floats("xyz")
e0 = x + y + z
e1 = x + y + z
e2 = x / y
C = Composite([x, y, z], [e0, e1, e2])
c = C.make_node(x, y, z)
g = FunctionGraph([x, y, z], c.outputs)
fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0, 3.0) == [6.0, 6.0, 0.5]

def test_composite_printing(self):
x, y, z = floats("xyz")
e0 = x + y + z
Expand Down
Loading