Skip to content

Commit

Permalink
Inline Numba scalar Ops
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jun 10, 2021
1 parent f41bbab commit 3acdd78
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions aesara/link/numba/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def {scalar_op_fn_name}({input_names}):

@numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs):
@numba.njit
@numba.njit(inline="always")
def switch(condition, x, y):
if condition:
return x
Expand Down Expand Up @@ -411,7 +411,7 @@ def numba_funcify_Add(op, node, **kwargs):

nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

return numba.njit(signature)(nary_add_fn)
return numba.njit(signature, inline="always")(nary_add_fn)


@numba_funcify.register(Mul)
Expand All @@ -421,7 +421,7 @@ def numba_funcify_Mul(op, node, **kwargs):

nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*")

return numba.njit(signature)(nary_mul_fn)
return numba.njit(signature, inline="always")(nary_mul_fn)


def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs):
Expand Down

0 comments on commit 3acdd78

Please sign in to comment.