diff --git a/python/cuda_parallel/cuda/parallel/experimental/_caching.py b/python/cuda_parallel/cuda/parallel/experimental/_caching.py index fa9165b7ad8..7794e6a7fae 100644 --- a/python/cuda_parallel/cuda/parallel/experimental/_caching.py +++ b/python/cuda_parallel/cuda/parallel/experimental/_caching.py @@ -48,29 +48,24 @@ class CachableFunction: def __init__(self, func): self._func = func + self._identity = None + + @property + def identity(self): + if self._identity is not None: + return self._identity + self._identity = ( + self._func.__code__.co_code, + self._func.__code__.co_consts, + self._func.__closure__, + ) + return self._identity def __eq__(self, other): - func1, func2 = self._func, other._func - - # return True if the functions compare equal for - # caching purposes, False otherwise - code1 = func1.__code__ - code2 = func2.__code__ - - return ( - code1.co_code == code2.co_code - and code1.co_consts == code2.co_consts - and func1.__closure__ == func2.__closure__ - ) + return self.identity == other.identity def __hash__(self): - return hash( - ( - self._func.__code__.co_code, - self._func.__code__.co_consts, - self._func.__closure__, - ) - ) + return hash(self.identity) def __repr__(self): return str(self._func) diff --git a/python/cuda_parallel/cuda/parallel/experimental/iterators/_iterators.py b/python/cuda_parallel/cuda/parallel/experimental/iterators/_iterators.py index 8ed4174da07..f1cf27c7797 100644 --- a/python/cuda_parallel/cuda/parallel/experimental/iterators/_iterators.py +++ b/python/cuda_parallel/cuda/parallel/experimental/iterators/_iterators.py @@ -1,5 +1,6 @@ import ctypes import operator +import uuid from functools import lru_cache from typing import Dict, Callable @@ -18,6 +19,14 @@ _DEVICE_POINTER_BITWIDTH = _DEVICE_POINTER_SIZE * 8 +@lru_cache(maxsize=None) +def _get_abi_suffix(kind: "IteratorKind"): + # given an IteratorKind, return a UUID. The value + # is cached so that the same UUID is always returned + # for a given IteratorKind. + return uuid.uuid4().hex + + @lru_cache(maxsize=256) # TODO: what's a reasonable value? def cached_compile(func, sig, abi_name=None, **kwargs): return cuda.compile(func, sig, abi_info={"abi_name": abi_name}, **kwargs) @@ -60,7 +69,6 @@ def __init__( cvalue: ctypes.c_void_p, numba_type: types.Type, value_type: types.Type, - abi_name: str, ): """ Parameters @@ -72,14 +80,10 @@ def __init__( and dereference functions. value_type The numba type of the value returned by the dereference operation. - abi_name - A unique identifier that will determine the abi_names for the - advance and dereference operations. """ self.cvalue = cvalue self.numba_type = numba_type self.value_type = value_type - self.abi_name = abi_name @property def kind(self): @@ -90,8 +94,8 @@ def kind(self): # needed. @property def ltoirs(self) -> Dict[str, bytes]: - advance_abi_name = self.abi_name + "_advance" - deref_abi_name = self.abi_name + "_dereference" + advance_abi_name = "advance_" + _get_abi_suffix(self.kind) + deref_abi_name = "dereference_" + _get_abi_suffix(self.kind) advance_ltoir, _ = cached_compile( self.__class__.advance, ( @@ -123,18 +127,16 @@ def dereference(state): raise NotImplementedError("Subclasses must override dereference staticmethod") def __hash__(self): - return hash( - (self.cvalue.value, self.numba_type, self.value_type, self.abi_name) - ) + return hash((self.kind, self.cvalue.value, self.numba_type, self.value_type)) def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented return ( - self.cvalue.value == other.cvalue.value + self.kind == other.kind + and self.cvalue.value == other.cvalue.value and self.numba_type == other.numba_type and self.value_type == other.value_type - and self.abi_name == other.abi_name ) @@ -178,12 +180,10 @@ class RawPointer(IteratorBase): def __init__(self, ptr: int, value_type: types.Type): cvalue = ctypes.c_void_p(ptr) numba_type = types.CPointer(types.CPointer(value_type)) - abi_name = f"{self.__class__.__name__}_{str(value_type)}" super().__init__( cvalue=cvalue, numba_type=numba_type, value_type=value_type, - abi_name=abi_name, ) @staticmethod @@ -231,12 +231,10 @@ def __init__(self, ptr: int, ntype: types.Type): cvalue = ctypes.c_void_p(ptr) value_type = ntype numba_type = types.CPointer(types.CPointer(value_type)) - abi_name = f"{self.__class__.__name__}_{str(value_type)}" super().__init__( cvalue=cvalue, numba_type=numba_type, value_type=value_type, - abi_name=abi_name, ) @staticmethod @@ -259,12 +257,10 @@ def __init__(self, value: np.number): value_type = numba.from_dtype(value.dtype) cvalue = to_ctypes(value_type)(value) numba_type = types.CPointer(value_type) - abi_name = f"{self.__class__.__name__}_{str(value_type)}" super().__init__( cvalue=cvalue, numba_type=numba_type, value_type=value_type, - abi_name=abi_name, ) @staticmethod @@ -287,12 +283,10 @@ def __init__(self, value: np.number): value_type = numba.from_dtype(value.dtype) cvalue = to_ctypes(value_type)(value) numba_type = types.CPointer(value_type) - abi_name = f"{self.__class__.__name__}_{str(value_type)}" super().__init__( cvalue=cvalue, numba_type=numba_type, value_type=value_type, - abi_name=abi_name, ) @staticmethod @@ -327,27 +321,20 @@ def __init__(self, it: IteratorBase, op: CUDADispatcher): self._it = it self._op = CachableFunction(op.py_func) numba_type = it.numba_type - # TODO: the abi name below isn't unique enough when we have e.g., - # two identically named `op` functions with different - # signatures, bytecodes, and/or closure variables. - op_abi_name = f"{self.__class__.__name__}_{op.py_func.__name__}" - # TODO: it would be nice to not need to compile `op` to get # its return type, but there's nothing in the numba API # to do that (yet), _, op_retty = cached_compile( op, (self._it.value_type,), - abi_name=op_abi_name, + abi_name=f"{op.__name__}_{_get_abi_suffix(self.kind)}", output="ltoir", ) value_type = op_retty - abi_name = f"{self.__class__.__name__}_{it.abi_name}_{op_abi_name}" super().__init__( cvalue=it.cvalue, numba_type=numba_type, value_type=value_type, - abi_name=abi_name, ) @property @@ -363,16 +350,10 @@ def dereference(state): return op(it_dereference(state)) def __hash__(self): - return hash( - ( - self._it, - self._op._func.py_func.__code__.co_code, - self._op._func.py_func.__closure__, - ) - ) + return hash((self._it, self._op.identity)) def __eq__(self, other): - if not isinstance(other, IteratorBase): + if not isinstance(other.kind, TransformIteratorKind): return NotImplemented return self._it == other._it and self._op == other._op diff --git a/python/cuda_parallel/tests/test_iterators.py b/python/cuda_parallel/tests/test_iterators.py index 89ae076cb4c..ef94800df43 100644 --- a/python/cuda_parallel/tests/test_iterators.py +++ b/python/cuda_parallel/tests/test_iterators.py @@ -48,6 +48,7 @@ def test_cache_modified_input_iterator_equality(): assert it1 == it2 assert it1 != it3 + assert it1.kind == it2.kind == it3.kind assert it1.kind != it4.kind @@ -71,6 +72,7 @@ def op3(x): assert it1 == it2 assert it1 != it3 assert it1 == it4 + assert it1.kind == it2.kind == it4.kind ary1 = cp.asarray([0, 1, 2])