worker: use pickle instead of pyon for ipc #1675

Draft · wants to merge 9 commits into base: master
14 changes: 14 additions & 0 deletions artiq/master/dask_serialize/__init__.py
@@ -0,0 +1,14 @@
"""
Serialization / deserialization functions extracted from the Dask library

Dask is a powerful parallel computing library (licensed under BSD) and, as such,
has capable serialization features. However, it is also quite a heavy dependency
(pulling in lots of others), so this subpackage contains only the
(de)serialization code, with a few minor tweaks to decouple it from the rest of
the Dask library.

See issue #1674 for more information about this decision.
"""


from .serialize import serialize, deserialize
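
For context, here is a minimal round-trip sketch of how the re-exported pair might be used, assuming the extracted API keeps Dask's (header, frames) convention: `serialize()` returns a small metadata dict plus a list of bytes-like frames, and `deserialize()` reverses it. Illustration only, not part of this diff.

```python
import numpy as np

from artiq.master.dask_serialize import serialize, deserialize

payload = {"samples": np.arange(1024, dtype=np.float64), "run": 42}

# serialize() is assumed to return (header, frames): a metadata dict plus a
# list of bytes-like frames suitable for sending over a pipe.
header, frames = serialize(payload)
restored = deserialize(header, frames)

assert restored["run"] == 42
assert np.array_equal(restored["samples"], payload["samples"])
```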
186 changes: 186 additions & 0 deletions artiq/master/dask_serialize/compression.py
@@ -0,0 +1,186 @@
"""
Record known compressors

Includes utilities for determining whether or not to compress
[Review comment from a project member] Sounds weird to compress data for inter-process communication. IPC is not supposed to be slow, and compressing data to work around IPC slowness sounds like a band-aid.

"""
from __future__ import annotations

import logging
import random
from collections.abc import Callable
from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING


blosc = False  # blosc support is not included in this extraction; only zlib (when importable) is registered below


def ensure_bytes(s):
"""Attempt to turn `s` into bytes.

Parameters
----------
s : Any
The object to be converted. The following types are handled correctly:

* str
* bytes
* objects implementing the buffer protocol (memoryview, ndarray, etc.)

Returns
-------
b : bytes

Raises
------
TypeError
When `s` cannot be converted

Examples
--------
>>> ensure_bytes('123')
b'123'
>>> ensure_bytes(b'123')
b'123'
"""
if isinstance(s, bytes):
return s
elif hasattr(s, "encode"):
return s.encode()
else:
try:
return bytes(s)
except Exception as e:
raise TypeError(
"Object %s is neither a bytes object nor has an encode method" % s
) from e


identity = lambda x: x

if TYPE_CHECKING:
from typing_extensions import Literal

compressions: dict[
str | None | Literal[False],
dict[Literal["compress", "decompress"], Callable[[bytes], bytes]],
] = {None: {"compress": identity, "decompress": identity}}

compressions[False] = compressions[None] # alias


default_compression = None


logger = logging.getLogger(__name__)


with suppress(ImportError):
import zlib

compressions["zlib"] = {"compress": zlib.compress, "decompress": zlib.decompress}


def get_default_compression():
default = "auto"  # hard-coded here; upstream Dask reads this value from its configuration system
if default != "auto":
if default in compressions:
return default
else:
raise ValueError(
"Default compression '%s' not found.\n"
"Choices include auto, %s"
% (default, ", ".join(sorted(map(str, compressions))))
)
else:
return default_compression


get_default_compression()  # called once at import time; with the hard-coded "auto" default this simply returns default_compression


def byte_sample(b, size, n):
"""Sample a bytestring from many locations

Parameters
----------
b : bytes or memoryview
size : int
size of each sample to collect
n : int
number of samples to collect
"""
starts = [random.randint(0, len(b) - size) for j in range(n)]
ends = []
for i, start in enumerate(starts[:-1]):
ends.append(min(start + size, starts[i + 1]))
ends.append(starts[-1] + size)

parts = [b[start:end] for start, end in zip(starts, ends)]
return b"".join(map(ensure_bytes, parts))


def maybe_compress(
payload,
min_size=1e4,
sample_size=1e4,
nsamples=5,
compression="auto",
):
"""
Maybe compress payload

1. We don't compress small messages
2. We sample the payload in a few spots, compress that, and if it doesn't
do any good we return the original
3. We then compress the full original; if it doesn't compress well, we
return the original
4. We return the compressed result
"""
if compression == "auto":
compression = default_compression

if not compression:
return None, payload
if len(payload) < min_size:
return None, payload
if len(payload) > 2 ** 31: # Too large, compression libraries often fail
return None, payload

min_size = int(min_size)
sample_size = int(sample_size)

compress = compressions[compression]["compress"]

# Compress a sample, return original if not very compressed
sample = byte_sample(payload, sample_size, nsamples)
if len(compress(sample)) > 0.9 * len(sample): # sample not very compressible
return None, payload

if type(payload) is memoryview:
nbytes = payload.itemsize * len(payload)
else:
nbytes = len(payload)

if default_compression and blosc and type(payload) is memoryview:
# Blosc does itemsize-aware shuffling, resulting in better compression
compressed = blosc.compress(
payload, typesize=payload.itemsize, cname="lz4", clevel=5
)
compression = "blosc"
else:
compressed = compress(ensure_bytes(payload))

if len(compressed) > 0.9 * nbytes: # full data not very compressible
return None, payload
else:
return compression, compressed


def decompress(header, frames):
"""Decompress frames according to information in the header"""
return [
compressions[c]["decompress"](frame)
for c, frame in zip(header["compression"], frames)
]
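
A short usage sketch for the helpers above (illustration only, not part of this diff): `maybe_compress()` returns a `(compression_name, data)` pair, with `None` as the name when the payload is left untouched, and `decompress()` undoes the per-frame choice recorded in the header.

```python
from artiq.master.dask_serialize.compression import maybe_compress, decompress

frame = b"0123456789" * 5000  # ~50 kB of highly repetitive data

# Explicitly request zlib; with the module defaults ("auto" -> None) the
# frame would simply be passed through uncompressed.
compression, data = maybe_compress(frame, compression="zlib")
# compression == "zlib" here because the sampled bytes compressed well;
# a small or incompressible frame comes back as (None, frame) unchanged.

header = {"compression": [compression]}
(restored,) = decompress(header, [data])
assert restored == frame
```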
92 changes: 92 additions & 0 deletions artiq/master/dask_serialize/pickle.py
@@ -0,0 +1,92 @@
import logging
import sys

# import cloudpickle

if sys.version_info < (3, 8):
try:
import pickle5 as pickle
except ImportError:
import pickle
else:
import pickle


HIGHEST_PROTOCOL = pickle.HIGHEST_PROTOCOL

logger = logging.getLogger(__name__)


# Retained from the Dask original; only referenced by the cloudpickle branch
# that is commented out in dumps() below.
def _always_use_pickle_for(x):
mod, _, _ = x.__class__.__module__.partition(".")
if mod == "numpy":
import numpy as np

return isinstance(x, np.ndarray)
elif mod == "pandas":
import pandas as pd

return isinstance(x, pd.core.generic.NDFrame)
elif mod == "builtins":
return isinstance(x, (str, bytes))
else:
return False


def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL):
"""Manage between cloudpickle and pickle

1. Try pickle
2. If it is short then check if it contains __main__
3. If it is long, then first check type, then check __main__
"""
buffers = []
dump_kwargs = {"protocol": protocol or HIGHEST_PROTOCOL}
if dump_kwargs["protocol"] >= 5 and buffer_callback is not None:
dump_kwargs["buffer_callback"] = buffers.append

# CFAB notes: I'm removing all this to avoid the cloudpickle dependency.
# cloudpickle is primarily for serializing classes and functions,
# particularly dynamically defined ones. We don't need that for ARTIQ's IPC,
# so we can safely ignore it.
# try:
# buffers.clear()
# result = pickle.dumps(x, **dump_kwargs)
# if len(result) < 1000:
# if b"__main__" in result:
# buffers.clear()
# result = cloudpickle.dumps(x, **dump_kwargs)
# elif not _always_use_pickle_for(x) and b"__main__" in result:
# buffers.clear()
# result = cloudpickle.dumps(x, **dump_kwargs)
# except Exception:
# try:
# buffers.clear()
# result = cloudpickle.dumps(x, **dump_kwargs)
# except Exception as e:
# logger.info("Failed to serialize %s. Exception: %s", x, e)
# raise

# Simplified version here:
try:
buffers.clear()
result = pickle.dumps(x, **dump_kwargs)
except Exception as e:
logger.info("Failed to serialize %s. Exception: %s", x, e)
raise

if buffer_callback is not None:
for b in buffers:
buffer_callback(b)
return result


def loads(x, *, buffers=()):
try:
if buffers:
return pickle.loads(x, buffers=buffers)
else:
return pickle.loads(x)
except Exception:
logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
raise
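
And a round-trip sketch for the pickle wrappers (again, illustration only, not part of this diff): with protocol 5 (Python 3.8+ or the pickle5 backport) large buffers are handed to ``buffer_callback`` out-of-band, so the worker IPC layer could ship them as separate zero-copy frames.

```python
import numpy as np

from artiq.master.dask_serialize.pickle import dumps, loads

obj = {"waveform": np.zeros(100_000), "name": "calibration"}

out_of_band = []  # receives pickle.PickleBuffer objects for large buffers
blob = dumps(obj, buffer_callback=out_of_band.append)

# `blob` holds only the pickle stream; the ndarray payload sits in
# `out_of_band` and could be transferred as separate frames.
restored = loads(blob, buffers=out_of_band)
assert np.array_equal(restored["waveform"], obj["waveform"])
```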