Skip to content

Commit

Permalink
better support+test for Array subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Sep 18, 2024
1 parent 8eec073 commit fdb3525
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 25 deletions.
37 changes: 13 additions & 24 deletions pyopencl/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class DoubleDowncastWarning(UserWarning):


_DOUBLE_DOWNCAST_WARNING = (
"The operation you requested would result in a double-precisision "
"The operation you requested would result in a double-precision "
"quantity according to numpy semantics. Since your device does not "
"support double precision, a single-precision quantity is being returned.")

Expand Down Expand Up @@ -748,19 +748,14 @@ def __getstate__(self):
raise RuntimeError("CL Array instances can only be pickled while "
"queue_for_pickling is active.")

d = {}
d["shape"] = self.shape
d["dtype"] = self.dtype
d["strides"] = self.strides
d["allocator"] = self.allocator
d["nbytes"] = self.nbytes
d["strides"] = self.strides
d["offset"] = self.offset
d["data"] = self.get(queue=queue)
d["_flags"] = self._flags
d["size"] = self.size
state = self.__dict__.copy()
del state["context"]
del state["events"]
del state["queue"]
del state["base_data"]
state["data"] = self.get(queue=queue)

return d
return state

def __setstate__(self, state):
try:
Expand All @@ -772,20 +767,14 @@ def __setstate__(self, state):
raise RuntimeError("CL Array instances can only be pickled while "
"queue_for_pickling is active.")

self.queue = queue
self.shape = state["shape"]
self.dtype = state["dtype"]
self.strides = state["strides"]
self.allocator = state["allocator"]
self.offset = state["offset"]
self._flags = state["_flags"]
self.size = state["size"]
self.nbytes = state["nbytes"]
self.__dict__.update(state)
self.context = queue.context
self.events = []
self.queue = queue

if self.allocator is None:
context = queue.context
self.base_data = cl.Buffer(context, cl.mem_flags.READ_WRITE, self.nbytes)
self.base_data = cl.Buffer(self.context, cl.mem_flags.READ_WRITE,
self.nbytes)
else:
self.base_data = self.allocator(self.nbytes)

Expand Down
21 changes: 20 additions & 1 deletion test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2395,6 +2395,15 @@ def test_xdg_cache_home(ctx_factory):

# {{{ test pickling

from pytools.tag import Taggable


class TaggableCLArray(cl_array.Array, Taggable):
def __init__(self, cq, shape, dtype, tags):
super().__init__(cq=cq, shape=shape, dtype=dtype)
self.tags = tags


def test_array_pickling(ctx_factory):
context = ctx_factory()
queue = cl.CommandQueue(context)
Expand All @@ -2406,10 +2415,20 @@ def test_array_pickling(ctx_factory):
with pytest.raises(RuntimeError):
pickle.dumps(a_gpu)

with cl.array.queue_for_pickling(queue):
with cl_array.queue_for_pickling(queue):
a_gpu_pickled = pickle.loads(pickle.dumps(a_gpu))
assert np.all(a_gpu_pickled.get() == a)

a_gpu_tagged = TaggableCLArray(queue, a.shape, a.dtype, tags={"foo", "bar"})
a_gpu_tagged.set(a)

with cl_array.queue_for_pickling(queue):
a_gpu_tagged_pickled = pickle.loads(pickle.dumps(a_gpu_tagged))

assert np.all(a_gpu_tagged_pickled.get() == a)
assert a_gpu_tagged_pickled.tags == a_gpu_tagged.tags


# }}}


Expand Down

0 comments on commit fdb3525

Please sign in to comment.