diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index 79960e52fa..1cd304f5e5 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -17,6 +17,7 @@ from numba.extending import box from aesara import config +from aesara.compile.builders import OpFromGraph from aesara.compile.ops import DeepCopyOp from aesara.graph.basic import Apply, NoParams from aesara.graph.fg import FunctionGraph @@ -374,6 +375,25 @@ def perform(*inputs): return perform +@numba_funcify.register(OpFromGraph) +def numba_funcify_OpFromGraph(op, node=None, **kwargs): + fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs)) + + if len(op.fgraph.outputs) == 1: + + @numba_njit + def opfromgraph(*inputs): + return fgraph_fn(*inputs)[0] + + else: + + @numba_njit + def opfromgraph(*inputs): + return fgraph_fn(*inputs) + + return opfromgraph + + @numba_funcify.register(FunctionGraph) def numba_funcify_FunctionGraph( fgraph, diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 96ef53203d..29f07649bf 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -12,6 +12,7 @@ import aesara.tensor as at import aesara.tensor.math as aem from aesara import config, shared +from aesara.compile.builders import OpFromGraph from aesara.compile.function import function from aesara.compile.mode import Mode from aesara.compile.ops import ViewOp @@ -1003,3 +1004,18 @@ def test_scalar_return_value_conversion(): mode=numba_mode, ) assert isinstance(x_fn(1.0), np.ndarray) + + +def test_OpFromGraph(): + x, y, z = at.matrices("xyz") + ofg_1 = OpFromGraph([x, y], [x + y], inline=False) + ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False) + + o1, o2 = ofg_2(y, z) + out = ofg_1(x, o1) + o2 + + xv = np.ones((2, 2), dtype=config.floatX) + yv = np.ones((2, 2), dtype=config.floatX) * 3 + zv = np.ones((2, 2), dtype=config.floatX) * 5 + + compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv])