Skip to content

Commit

Permalink
Fix dtype conversion for scalar Numba Dot results
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed May 12, 2021
1 parent 54e19a6 commit 9926e07
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion aesara/link/numba/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,7 +1424,7 @@ def numba_funcify_Dot(op, node, **kwargs):

@numba.njit
def dot(x, y):
return np.dot(inputs_cast(x), inputs_cast(y)).astype(out_dtype)
return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)

return dot

Expand Down
5 changes: 5 additions & 0 deletions tests/link/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,6 +1694,11 @@ def test_BroadcastTo(x, shape, exc):
),
None,
),
(
set_test_value(aet.lvector(), np.random.random(size=(2,)).astype(np.int64)),
set_test_value(aet.lvector(), np.random.random(size=(2,)).astype(np.int64)),
None,
),
],
)
def test_Dot(x, y, exc):
Expand Down

0 comments on commit 9926e07

Please sign in to comment.