Skip to content

Commit

Permalink
Disable numba cache for cython functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Adrian Seyboldt committed Nov 11, 2022
1 parent 7e6ab72 commit 52e1523
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 6 additions & 2 deletions aesara/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,14 @@

def numba_njit(*args, **kwargs):

kwargs = kwargs.copy()
if "cache" not in kwargs:
kwargs["cache"] = config.numba__cache

if len(args) > 0 and callable(args[0]):
return numba.njit(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
return numba.njit(*args[1:], **kwargs)(args[0])

return numba.njit(*args, cache=config.numba__cache, **kwargs)
return numba.njit(*args, **kwargs)


def numba_vectorize(*args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion aesara/link/numba/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
signature = create_numba_signature(node, force_scalar=True)

return numba_basic.numba_njit(
signature, inline="always", fastmath=config.numba__fastmath
signature, inline="always", fastmath=config.numba__fastmath, cache=False,
)(scalar_op_fn)


Expand Down

0 comments on commit 52e1523

Please sign in to comment.