Skip to content

Commit

Permalink
Add a Numba OpFromGraph implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Nov 11, 2022
1 parent 308969c commit 643c973
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
20 changes: 20 additions & 0 deletions aesara/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from numba.extending import box

from aesara import config
from aesara.compile.builders import OpFromGraph
from aesara.compile.ops import DeepCopyOp
from aesara.graph.basic import Apply, NoParams
from aesara.graph.fg import FunctionGraph
Expand Down Expand Up @@ -374,6 +375,25 @@ def perform(*inputs):
return perform


@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))

if len(op.fgraph.outputs) == 1:

@numba_njit
def opfromgraph(*inputs):
return fgraph_fn(*inputs)[0]

else:

@numba_njit
def opfromgraph(*inputs):
return fgraph_fn(*inputs)

return opfromgraph


@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
fgraph,
Expand Down
16 changes: 16 additions & 0 deletions tests/link/numba/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import aesara.tensor as at
import aesara.tensor.math as aem
from aesara import config, shared
from aesara.compile.builders import OpFromGraph
from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.compile.ops import ViewOp
Expand Down Expand Up @@ -1003,3 +1004,18 @@ def test_scalar_return_value_conversion():
mode=numba_mode,
)
assert isinstance(x_fn(1.0), np.ndarray)


def test_OpFromGraph():
x, y, z = at.matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)

o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2

xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5

compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv])

0 comments on commit 643c973

Please sign in to comment.