Use objmode in scipy.special without numba-scipy #1078

Draft · wants to merge 9 commits into base: main
2 changes: 2 additions & 0 deletions aesara/graph/fg.py
@@ -484,7 +484,9 @@ def replace(
f"rewriting: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}"
)

+        name = new_var.name
        new_var = var.type.filter_variable(new_var, allow_convert=True)
+        new_var.name = name

if var not in self.variables:
# TODO: Raise an actual exception here.
20 changes: 13 additions & 7 deletions aesara/link/numba/dispatch/basic.py
@@ -48,10 +48,14 @@

def numba_njit(*args, **kwargs):

+    kwargs = kwargs.copy()
+    if "cache" not in kwargs:
+        kwargs["cache"] = config.numba__cache
+
    if len(args) > 0 and callable(args[0]):
-        return numba.njit(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
+        return numba.njit(*args[1:], **kwargs)(args[0])

-    return numba.njit(*args, cache=config.numba__cache, **kwargs)
+    return numba.njit(*args, **kwargs)
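
# Usage sketch (illustrative, not taken from the diff): the wrapper supports
# both decorator styles, and copying kwargs means an explicit `cache`
# argument now takes precedence over the config default:
#
#     @numba_njit
#     def add_one(x):
#         return x + 1
#
#     @numba_njit(fastmath=True, cache=False)
#     def add_two(x):
#         return x + 2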


def numba_vectorize(*args, **kwargs):
@@ -319,10 +323,8 @@ def numba_typify(data, dtype=None, **kwargs):
return data


-@singledispatch
-def numba_funcify(op, node=None, storage_map=None, **kwargs):
+def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""

warnings.warn(
f"Numba will use object mode to run {op}'s perform method",
UserWarning,
@@ -375,6 +377,12 @@ def perform(*inputs):
return perform


+@singledispatch
+def numba_funcify(op, node=None, storage_map=None, **kwargs):
+    """Generate a numba function for a given op and apply node."""
+    return generate_fallback_impl(op, node, storage_map, **kwargs)
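
# Background sketch (illustrative, names hypothetical): the fallback runs an
# op's perform method in object mode, which is what the PR title refers to.
# The general numba.objmode pattern looks like this:
#
#     from numba import njit, objmode
#
#     @njit
#     def wrapped(x):
#         with objmode(y="float64[:]"):  # declare the type objmode returns
#             y = python_only_function(x)  # runs interpreted, not compiled
#         return y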


@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):

@@ -506,7 +514,6 @@ def {fn_name}({", ".join(input_names)}):


@numba_funcify.register(Subtensor)
-@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs):

@@ -524,7 +531,6 @@ def numba_funcify_Subtensor(op, node, **kwargs):


@numba_funcify.register(IncSubtensor)
-@numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_IncSubtensor(op, node, **kwargs):

incsubtensor_def_src = create_index_func(
211 changes: 211 additions & 0 deletions aesara/link/numba/dispatch/cython_support.py
@@ -0,0 +1,211 @@
import ctypes
import importlib
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast

import numba
import numpy as np
from numpy.typing import DTypeLike
from scipy import LowLevelCallable


_C_TO_NUMPY: Dict[str, DTypeLike] = {
"bool": np.bool_,
"signed char": np.byte,
"unsigned char": np.ubyte,
"short": np.short,
"unsigned short": np.ushort,
"int": np.intc,
"unsigned int": np.uintc,
"long": np.int_,
"unsigned long": np.uint,
"long long": np.longlong,
"float": np.single,
"double": np.double,
"long double": np.longdouble,
"float complex": np.csingle,
"double complex": np.cdouble,
}


@dataclass
class Signature:
res_dtype: DTypeLike
res_c_type: str
arg_dtypes: List[DTypeLike]
arg_c_types: List[str]
arg_names: List[Optional[str]]

@property
def arg_numba_types(self) -> List[DTypeLike]:
return [numba.from_dtype(dtype) for dtype in self.arg_dtypes]

def can_cast_args(self, args: List[DTypeLike]) -> bool:
ok = True
count = 0
for name, dtype in zip(self.arg_names, self.arg_dtypes):
if name == "__pyx_skip_dispatch":
continue
if len(args) <= count:
raise ValueError("Incorrect number of arguments")
ok &= np.can_cast(args[count], dtype)
count += 1
if count != len(args):
return False
return ok

def provides(self, restype: DTypeLike, arg_dtypes: List[DTypeLike]) -> bool:
args_ok = self.can_cast_args(arg_dtypes)
if np.issubdtype(restype, np.inexact):
result_ok = np.can_cast(self.res_dtype, restype, casting="same_kind")
# We do not want to provide less accuracy than advertised
result_ok &= np.dtype(self.res_dtype).itemsize >= np.dtype(restype).itemsize
else:
result_ok = np.can_cast(self.res_dtype, restype)
return args_ok and result_ok

@staticmethod
def from_c_types(signature: bytes) -> "Signature":
# Match strings like "double(int, double)"
# and extract the return type and the joined arguments
expr = re.compile(rb"\s*(?P<restype>[\w ]*\w+)\s*\((?P<args>[\w\s,]*)\)")
re_match = re.fullmatch(expr, signature)

if re_match is None:
raise ValueError(f"Invalid signature: {signature.decode()}")

groups = re_match.groupdict()
res_c_type = groups["restype"].decode()
res_dtype: DTypeLike = _C_TO_NUMPY[res_c_type]

raw_args = groups["args"]

decl_expr = re.compile(
rb"\s*(?P<type>((long )|(unsigned )|(signed )|(double )|)"
rb"((double)|(float)|(int)|(short)|(char)|(long)|(bool)|(complex)))"
rb"(\s(?P<name>[\w_]*))?\s*"
)

arg_dtypes = []
arg_names: List[Optional[str]] = []
arg_c_types = []
for raw_arg in raw_args.split(b","):
re_match = re.fullmatch(decl_expr, raw_arg)
if re_match is None:
raise ValueError(f"Invalid signature: {signature.decode()}")
groups = re_match.groupdict()
arg_c_type = groups["type"].decode()
try:
arg_dtype = _C_TO_NUMPY[arg_c_type]
except KeyError:
raise ValueError(f"Unknown C type: {arg_c_type}")

arg_c_types.append(arg_c_type)
arg_dtypes.append(arg_dtype)
name = groups["name"]
if not name:
arg_names.append(None)
else:
arg_names.append(name.decode())

return Signature(res_dtype, res_c_type, arg_dtypes, arg_c_types, arg_names)
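
# Worked example (illustrative values, not taken from the diff):
#
#     sig = Signature.from_c_types(b"double(int, double)")
#     sig.res_c_type                                     # "double"
#     sig.arg_c_types                                    # ["int", "double"]
#     sig.can_cast_args([np.int32, np.float32])          # True: safe casts
#     sig.provides(np.float64, [np.int32, np.float32])   # True: no accuracy loss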


def _available_impls(func: Callable) -> List[Tuple[Signature, Any]]:
"""Find all available implementations for a fused cython function."""
impls = []
mod = importlib.import_module(func.__module__)

signatures = getattr(func, "__signatures__", None)
if signatures is not None:
        # A Cython function that exposes __signatures__ is fused and can
        # therefore be indexed by signature key
func_map = cast(Mapping, func)
candidates = [func_map[key] for key in signatures]
else:
candidates = [func]
for candidate in candidates:
name = candidate.__name__
capsule = mod.__pyx_capi__[name]
llc = LowLevelCallable(capsule)
try:
signature = Signature.from_c_types(llc.signature.encode())
except KeyError:
continue
impls.append((signature, capsule))
return impls


class _CythonWrapper(numba.types.WrapperAddressProtocol):
def __init__(self, pyfunc, signature, capsule):
self._keep_alive = capsule
get_name = ctypes.pythonapi.PyCapsule_GetName
get_name.restype = ctypes.c_char_p
get_name.argtypes = (ctypes.py_object,)

raw_signature = get_name(capsule)

get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
get_pointer.restype = ctypes.c_void_p
get_pointer.argtypes = (ctypes.py_object, ctypes.c_char_p)
self._func_ptr = get_pointer(capsule, raw_signature)

self._signature = signature
self._pyfunc = pyfunc

def signature(self):
return numba.from_dtype(self._signature.res_dtype)(
*self._signature.arg_numba_types
)

def __wrapper_address__(self):
return self._func_ptr

def __call__(self, *args, **kwargs):
args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
if self.has_pyx_skip_dispatch():
output = self._pyfunc(*args[:-1], **kwargs)
else:
output = self._pyfunc(*args, **kwargs)
return self._signature.res_dtype(output)

def has_pyx_skip_dispatch(self):
if not self._signature.arg_names:
return False
if any(
name == "__pyx_skip_dispatch" for name in self._signature.arg_names[:-1]
):
raise ValueError("skip_dispatch parameter must be last")
return self._signature.arg_names[-1] == "__pyx_skip_dispatch"

def numpy_arg_dtypes(self):
return self._signature.arg_dtypes

def numpy_output_dtype(self):
return self._signature.res_dtype


def wrap_cython_function(func, restype, arg_types):
impls = _available_impls(func)
compatible = []
for sig, capsule in impls:
if sig.provides(restype, arg_types):
compatible.append((sig, capsule))

def sort_key(args):
sig, _ = args

        # Prefer functions with fewer input bytes
argsize = sum(np.dtype(dtype).itemsize for dtype in sig.arg_dtypes)

# Prefer functions with more exact (integer) arguments
num_inexact = sum(np.issubdtype(dtype, np.inexact) for dtype in sig.arg_dtypes)
return (num_inexact, argsize)

compatible.sort(key=sort_key)

if not compatible:
raise NotImplementedError(f"Could not find a compatible impl of {func}")
sig, capsule = compatible[0]
return _CythonWrapper(func, sig, capsule)
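
# Rough usage sketch, assuming scipy's cython_special module exposes gammaln
# through __pyx_capi__. As a cpdef function its capsule signature carries a
# trailing __pyx_skip_dispatch argument, so the wrapper takes one extra
# integer slot:
#
#     import numpy as np
#     from scipy.special import cython_special
#
#     wrapper = wrap_cython_function(
#         cython_special.gammaln, np.float64, [np.float64]
#     )
#     wrapper.numpy_output_dtype()   # np.float64
#     wrapper(3.5, 0)                # the 0 fills the __pyx_skip_dispatch slot
#
# Because _CythonWrapper implements numba's WrapperAddressProtocol, instances
# can also be passed into njit-compiled functions and called there as
# first-class functions.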
35 changes: 33 additions & 2 deletions aesara/link/numba/dispatch/elemwise.py
@@ -27,6 +27,7 @@
OR,
XOR,
Add,
+    Composite,
IntDiv,
Mean,
Mul,
@@ -40,6 +41,7 @@
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import MaxAndArgmax, MulWithoutZeros
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
+from aesara.tensor.type import scalar


@singledispatch
@@ -162,6 +164,15 @@ def create_vectorize_func(
return elemwise_fn


+def normalize_axis(axis, ndim):
+    if axis < 0:
+        axis = ndim + axis
+
+    if axis < 0 or axis >= ndim:
+        raise np.AxisError(ndim=ndim, axis=axis)
+    return axis
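
# Illustrative behavior (not taken from the diff): the helper mirrors NumPy's
# axis normalization:
#
#     normalize_axis(-1, 3)  # -> 2
#     normalize_axis(2, 3)   # -> 2
#     normalize_axis(3, 3)   # raises np.AxisError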


def create_axis_reducer(
scalar_op: Op,
identity: Union[np.ndarray, Number],
@@ -216,6 +227,8 @@ def careduce_axis(x):

"""

+    axis = normalize_axis(axis, ndim)

reduce_elemwise_fn_name = "careduce_axis"

identity = str(identity)
@@ -338,6 +351,8 @@ def careduce_maximum(input):
if len(axes) == 1:
return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)

+    axes = [normalize_axis(axis, ndim) for axis in axes]

careduce_fn_name = f"careduce_{scalar_op}"
global_env = {}
to_reduce = reversed(sorted(axes))
@@ -407,6 +422,8 @@ def jit_compile_reducer(node, fn, **kwds):


def create_axis_apply_fn(fn, axis, ndim, dtype):
+    axis = normalize_axis(axis, ndim)

reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)

@numba_basic.numba_njit(boundscheck=False)
@@ -424,8 +441,17 @@ def axis_apply_fn(x):

@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):

-    scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
+    # Creating a new scalar node is more involved and unnecessary
+    # if the scalar_op is a Composite, as the fgraph already contains
+    # all the necessary information.
+    scalar_node = None
+    if not isinstance(op.scalar_op, Composite):
+        scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
+        scalar_node = op.scalar_op.make_node(*scalar_inputs)
+
+    scalar_op_fn = numba_funcify(
+        op.scalar_op, node=scalar_node, parent_node=node, inline="always", **kwargs
+    )
elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
elemwise_fn_name = elemwise_fn.__name__

@@ -598,6 +624,8 @@ def numba_funcify_Softmax(op, node, **kwargs):
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis

    if axis is not None:
+        axis = normalize_axis(axis, x_at.ndim)
reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
@@ -635,6 +663,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)

axis = op.axis
    if axis is not None:
+        axis = normalize_axis(axis, sm_at.ndim)
reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
@@ -665,6 +694,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis
    if axis is not None:
+        axis = normalize_axis(axis, x_at.ndim)
reduce_max_py = create_axis_reducer(
@@ -699,6 +729,7 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
x_ndim = x_at.ndim
+    axis = normalize_axis(axis, x_ndim)

if x_ndim == 0:
