worker: use pickle instead of pyon for ipc #1675
Draft · charlesbaynham wants to merge 9 commits into m-labs:master from charlesbaynham:worker_comms_speed
+1,614 −82
Commits (9):
- 42aa9ad  Add speed test for worker writing (charlesbaynham)
- 0876d3e  Use Dask instead of pyon (charlesbaynham)
- 8ac16ed  Working on py36 (charlesbaynham)
- c6cb44b  Tidy and finish (charlesbaynham)
- 08ed519  Mem optimizations (charlesbaynham)
- 9cf0ccc  Fix on python 3.9 (charlesbaynham)
- 03a2e1f  Extract Dask serialization routines (charlesbaynham)
- 80a3371  Remove speculative compression libraries: we don't want behaviour cha… (charlesbaynham)
- 4a71965  Require msgpack (charlesbaynham)
Files changed:

File 1 of 3, a new file (@@ -0,0 +1,14 @@), evidently the subpackage's __init__.py given its relative import:
""" | ||
Serialisation / deserialization functions extracted from the Dask library | ||
|
||
Dask is a powerful parallel computing library (licensed under BSD) and, as such, | ||
has capable serialization features. However, it's also quite a heavy dependancy | ||
(pulling in lots of others) so this subpackage contains only the | ||
(de)serialization code, with a few minor tweaks to decouple it from the rest of | ||
the dask library. | ||
|
||
See issue #1674 for more information about this decision. | ||
""" | ||
|
||
|
||
from .serialize import serialize, deserialize |
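serialize.py itself is not part of the excerpt above, so the following round trip is only a sketch, assuming the extracted code keeps Dask's convention of `serialize` returning a `(header, frames)` pair that `deserialize` reverses (the `decompress(header, frames)` helper later in this diff is consistent with that). The import path and payload are illustrative, not taken from this PR.

```python
# Hypothetical usage sketch: assumes the Dask-style (header, frames) API.
# The module name is illustrative; serialize.py is not shown in this diff.
import numpy as np

from dask_serialization import serialize, deserialize  # hypothetical name

payload = {"rid": 17, "data": np.arange(1_000_000)}
header, frames = serialize(payload)     # metadata dict + list of byte buffers
restored = deserialize(header, frames)  # rebuild the original object
assert restored["rid"] == 17
assert (restored["data"] == payload["data"]).all()
```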
File 2 of 3, a new file (@@ -0,0 +1,186 @@) recording known compressors:
""" | ||
Record known compressors | ||
|
||
Includes utilities for determining whether or not to compress | ||
""" | ||
from __future__ import annotations | ||
|
||
import logging | ||
import random | ||
from collections.abc import Callable | ||
from contextlib import suppress | ||
from functools import partial | ||
from typing import TYPE_CHECKING | ||
|
||
|
||
blosc = False | ||
|
||
|
||
def ensure_bytes(s): | ||
"""Attempt to turn `s` into bytes. | ||
|
||
Parameters | ||
---------- | ||
s : Any | ||
The object to be converted. Will correctly handled | ||
|
||
* str | ||
* bytes | ||
* objects implementing the buffer protocol (memoryview, ndarray, etc.) | ||
|
||
Returns | ||
------- | ||
b : bytes | ||
|
||
Raises | ||
------ | ||
TypeError | ||
When `s` cannot be converted | ||
|
||
Examples | ||
-------- | ||
>>> ensure_bytes('123') | ||
b'123' | ||
>>> ensure_bytes(b'123') | ||
b'123' | ||
""" | ||
if isinstance(s, bytes): | ||
return s | ||
elif hasattr(s, "encode"): | ||
return s.encode() | ||
else: | ||
try: | ||
return bytes(s) | ||
except Exception as e: | ||
raise TypeError( | ||
"Object %s is neither a bytes object nor has an encode method" % s | ||
) from e | ||
|
||
|
||
identity = lambda x: x | ||
|
||
if TYPE_CHECKING: | ||
from typing_extensions import Literal | ||
|
||
compressions: dict[ | ||
str | None | Literal[False], | ||
dict[Literal["compress", "decompress"], Callable[[bytes], bytes]], | ||
] = {None: {"compress": identity, "decompress": identity}} | ||
|
||
compressions[False] = compressions[None] # alias | ||
|
||
|
||
default_compression = None | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
with suppress(ImportError): | ||
import zlib | ||
|
||
compressions["zlib"] = {"compress": zlib.compress, "decompress": zlib.decompress} | ||
|
||
|
||
def get_default_compression(): | ||
default = "auto" | ||
if default != "auto": | ||
if default in compressions: | ||
return default | ||
else: | ||
raise ValueError( | ||
"Default compression '%s' not found.\n" | ||
"Choices include auto, %s" | ||
% (default, ", ".join(sorted(map(str, compressions)))) | ||
) | ||
else: | ||
return default_compression | ||
|
||
|
||
get_default_compression() | ||
|
||
|
||
def byte_sample(b, size, n): | ||
"""Sample a bytestring from many locations | ||
|
||
Parameters | ||
---------- | ||
b : bytes or memoryview | ||
size : int | ||
size of each sample to collect | ||
n : int | ||
number of samples to collect | ||
""" | ||
starts = [random.randint(0, len(b) - size) for j in range(n)] | ||
ends = [] | ||
for i, start in enumerate(starts[:-1]): | ||
ends.append(min(start + size, starts[i + 1])) | ||
ends.append(starts[-1] + size) | ||
|
||
parts = [b[start:end] for start, end in zip(starts, ends)] | ||
return b"".join(map(ensure_bytes, parts)) | ||
|
||
|
||
def maybe_compress( | ||
payload, | ||
min_size=1e4, | ||
sample_size=1e4, | ||
nsamples=5, | ||
compression="auto", | ||
): | ||
""" | ||
Maybe compress payload | ||
|
||
1. We don't compress small messages | ||
2. We sample the payload in a few spots, compress that, and if it doesn't | ||
do any good we return the original | ||
3. We then compress the full original, it it doesn't compress well then we | ||
return the original | ||
4. We return the compressed result | ||
""" | ||
if compression == "auto": | ||
compression = default_compression | ||
|
||
if not compression: | ||
return None, payload | ||
if len(payload) < min_size: | ||
return None, payload | ||
if len(payload) > 2 ** 31: # Too large, compression libraries often fail | ||
return None, payload | ||
|
||
min_size = int(min_size) | ||
sample_size = int(sample_size) | ||
|
||
compress = compressions[compression]["compress"] | ||
|
||
# Compress a sample, return original if not very compressed | ||
sample = byte_sample(payload, sample_size, nsamples) | ||
if len(compress(sample)) > 0.9 * len(sample): # sample not very compressible | ||
return None, payload | ||
|
||
if type(payload) is memoryview: | ||
nbytes = payload.itemsize * len(payload) | ||
else: | ||
nbytes = len(payload) | ||
|
||
if default_compression and blosc and type(payload) is memoryview: | ||
# Blosc does itemsize-aware shuffling, resulting in better compression | ||
compressed = blosc.compress( | ||
payload, typesize=payload.itemsize, cname="lz4", clevel=5 | ||
) | ||
compression = "blosc" | ||
else: | ||
compressed = compress(ensure_bytes(payload)) | ||
|
||
if len(compressed) > 0.9 * nbytes: # full data not very compressible | ||
return None, payload | ||
else: | ||
return compression, compressed | ||
|
||
|
||
def decompress(header, frames): | ||
"""Decompress frames according to information in the header""" | ||
return [ | ||
compressions[c]["decompress"](frame) | ||
for c, frame in zip(header["compression"], frames) | ||
] |
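As a sanity check on the gating logic in `maybe_compress`, here is a short usage sketch (not part of the PR) that relies only on the zlib entry this module registers at import time; the data is made up for illustration.

```python
# Usage sketch for maybe_compress/decompress above (illustrative only).
import os

blob = b"abc123" * 10_000  # ~60 kB of highly repetitive data
name, wire = maybe_compress(blob, compression="zlib")
assert name == "zlib" and len(wire) < len(blob)

# Random data fails the sampled-compressibility check and passes through as-is
noise = os.urandom(60_000)
name, wire = maybe_compress(noise, compression="zlib")
assert name is None and wire is noise

# decompress() reverses per-frame compression, driven by the header
restored = decompress({"compression": ["zlib"]},
                      [maybe_compress(blob, compression="zlib")[1]])
assert restored == [blob]
```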
File 3 of 3, a new file (@@ -0,0 +1,92 @@), the pickle wrapper:
```python
import logging
import sys

# import cloudpickle

if sys.version_info < (3, 8):
    try:
        import pickle5 as pickle
    except ImportError:
        import pickle
else:
    import pickle


HIGHEST_PROTOCOL = pickle.HIGHEST_PROTOCOL

logger = logging.getLogger(__name__)


def _always_use_pickle_for(x):
    mod, _, _ = x.__class__.__module__.partition(".")
    if mod == "numpy":
        import numpy as np

        return isinstance(x, np.ndarray)
    elif mod == "pandas":
        import pandas as pd

        return isinstance(x, pd.core.generic.NDFrame)
    elif mod == "builtins":
        return isinstance(x, (str, bytes))
    else:
        return False


def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL):
    """Manage between cloudpickle and pickle

    1. Try pickle
    2. If it is short then check if it contains __main__
    3. If it is long, then first check type, then check __main__
    """
    buffers = []
    dump_kwargs = {"protocol": protocol or HIGHEST_PROTOCOL}
    if dump_kwargs["protocol"] >= 5 and buffer_callback is not None:
        dump_kwargs["buffer_callback"] = buffers.append

    # CFAB notes: I'm removing all this to avoid the cloudpickle dependency.
    # cloudpickle is primarily for serializing classes and functions,
    # particularly dynamically defined ones. We don't need that for ARTIQ's
    # IPC, so can safely ignore it.
    # try:
    #     buffers.clear()
    #     result = pickle.dumps(x, **dump_kwargs)
    #     if len(result) < 1000:
    #         if b"__main__" in result:
    #             buffers.clear()
    #             result = cloudpickle.dumps(x, **dump_kwargs)
    #     elif not _always_use_pickle_for(x) and b"__main__" in result:
    #         buffers.clear()
    #         result = cloudpickle.dumps(x, **dump_kwargs)
    # except Exception:
    #     try:
    #         buffers.clear()
    #         result = cloudpickle.dumps(x, **dump_kwargs)
    #     except Exception as e:
    #         logger.info("Failed to serialize %s. Exception: %s", x, e)
    #         raise

    # Simplified version here:
    try:
        buffers.clear()
        result = pickle.dumps(x, **dump_kwargs)
    except Exception as e:
        logger.info("Failed to serialize %s. Exception: %s", x, e)
        raise

    if buffer_callback is not None:
        for b in buffers:
            buffer_callback(b)
    return result


def loads(x, *, buffers=()):
    try:
        if buffers:
            return pickle.loads(x, buffers=buffers)
        else:
            return pickle.loads(x)
    except Exception:
        logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
        raise
```
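To illustrate what the `buffer_callback` plumbing buys on pickle protocol 5, here is a hedged round-trip sketch of `dumps`/`loads` (not from the PR). It assumes Python >= 3.8 and NumPy, whose arrays opt in to out-of-band buffers under protocol 5; the payload is made up.

```python
# Round-trip sketch (illustrative): protocol-5 out-of-band buffers let large
# payloads travel as separate frames instead of being copied into the pickle
# byte stream. NumPy arrays emit PickleBuffers when protocol >= 5.
import numpy as np

frames = []
payload = {"waveform": np.zeros(1_000_000)}

head = dumps(payload, buffer_callback=frames.append)  # frames fill via callback
restored = loads(head, buffers=frames)                # hand the frames back
assert (restored["waveform"] == payload["waveform"]).all()
```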
Review comment:

Sounds weird to compress data for inter-process communication. IPC is not supposed to be slow, and compressing data to work around IPC slowness sounds like a band-aid.