Skip to content

Commit

Permalink
Add pickling support for arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Sep 16, 2024
1 parent 5085bb3 commit b988e6a
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
88 changes: 88 additions & 0 deletions pyopencl/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,37 @@ class _copy_queue: # noqa: N801
_NOT_PRESENT = object()


# {{{ pickling support

import threading
from contextlib import contextmanager


_QUEUE_FOR_PICKLING_TLS = threading.local()


@contextmanager
def queue_for_pickling(queue):
r"""A context manager that, for the current thread, sets the command queue
to be used for pickling and unpickling :class:`Array`\ s to *queue*."""
try:
existing_pickle_queue = _QUEUE_FOR_PICKLING_TLS.queue
except AttributeError:
existing_pickle_queue = None

if existing_pickle_queue is not None:
raise RuntimeError("array_context_for_pickling should not be called "
"inside the context of its own invocation.")

_QUEUE_FOR_PICKLING_TLS.queue = queue
try:
yield None
finally:
_QUEUE_FOR_PICKLING_TLS.queue = None

# }}}


class Array:
"""A :class:`numpy.ndarray` work-alike that stores its data and performs
its computations on the compute device. :attr:`shape` and :attr:`dtype` work
Expand Down Expand Up @@ -705,6 +736,63 @@ def __init__(
"than expected, potentially leading to crashes.",
InconsistentOpenCLQueueWarning, stacklevel=2)

# {{{ Pickling

def __getstate__(self):
try:
queue = _QUEUE_FOR_PICKLING_TLS.queue
except AttributeError:
queue = None

if queue is None:
raise RuntimeError("CL Array instances can only be pickled while "
"queue_for_pickling is active.")

d = {}
d["shape"] = self.shape
d["dtype"] = self.dtype
d["strides"] = self.strides
d["allocator"] = self.allocator
d["nbytes"] = self.nbytes
d["strides"] = self.strides
d["offset"] = self.offset
d["data"] = self.get(queue=queue)
d["_flags"] = self._flags
d["size"] = self.size

return d

def __setstate__(self, state):
try:
queue = _QUEUE_FOR_PICKLING_TLS.queue
except AttributeError:
queue = None

if queue is None:
raise RuntimeError("CL Array instances can only be pickled while "
"queue_for_pickling is active.")

self.queue = queue
self.shape = state["shape"]
self.dtype = state["dtype"]
self.strides = state["strides"]
self.allocator = state["allocator"]
self.offset = state["offset"]
self._flags = state["_flags"]
self.size = state["size"]
self.nbytes = state["nbytes"]
self.events = []

if self.allocator is None:
context = queue.context
self.base_data = cl.Buffer(context, cl.mem_flags.READ_WRITE, self.nbytes)
else:
self.base_data = self.allocator(self.nbytes)

self.set(state["data"], queue=queue)

# }}}

@property
def ndim(self):
return len(self.shape)
Expand Down
20 changes: 20 additions & 0 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2393,6 +2393,26 @@ def test_xdg_cache_home(ctx_factory):
# }}}


# {{{ test pickling

def test_array_pickling(ctx_factory):
context = ctx_factory()
queue = cl.CommandQueue(context)

a = np.array([1, 2, 3, 4, 5]).astype(np.float32)
a_gpu = cl_array.to_device(queue, a)

import pickle
with pytest.raises(RuntimeError):
pickle.dumps(a_gpu)

with cl.array.queue_for_pickling(queue):
a_gpu_pickled = pickle.loads(pickle.dumps(a_gpu))
assert np.all(a_gpu_pickled.get() == a)

# }}}


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit b988e6a

Please sign in to comment.