Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perform custom deserializations #1

Open
wants to merge 2 commits into
base: v5.2.4
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions kombu/asynchronous/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,24 @@

from kombu.log import get_logger

try:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No change necessary, but to save us the conditional handling of ZoneInfo (we only want UTC), we could use Python's datetime.timezone.utc.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am +1 on this, seems like a less risky change than introducing zoneinfo

Copy link
Author

@jbkkd jbkkd Oct 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right that we could do that, however this change is taken directly from kombu:
celery#1680
Seeing that ZoneInfo will be added once we upgrade regardless, we might as well do this now. I prefer sticking to whatever kombu is doing even if it isn't the best solution.

from pytz import utc
except ImportError: # pragma: no cover
utc = None

if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
else:
from backports.zoneinfo import ZoneInfo

__all__ = ('Entry', 'Timer', 'to_timestamp')

logger = get_logger(__name__)

DEFAULT_MAX_INTERVAL = 2
EPOCH = datetime.utcfromtimestamp(0).replace(tzinfo=utc)
EPOCH = datetime.utcfromtimestamp(0).replace(tzinfo=ZoneInfo("UTC"))
IS_PYPY = hasattr(sys, 'pypy_version_info')

scheduled = namedtuple('scheduled', ('eta', 'priority', 'entry'))


def to_timestamp(d, default_timezone=utc, time=monotonic):
def to_timestamp(d, default_timezone=ZoneInfo("UTC"), time=monotonic):
"""Convert datetime to timestamp.

If d' is already a timestamp, then that will be used.
Expand Down
74 changes: 63 additions & 11 deletions kombu/utils/json.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""JSON Serialization Utilities."""

import base64
import datetime
import decimal
from decimal import Decimal
import json as stdjson
import uuid
from typing import Any, Callable, TypeVar


try:
from django.utils.functional import Promise as DjangoPromise
Expand Down Expand Up @@ -69,7 +72,19 @@ def dumps(s, _dumps=json.dumps, cls=None, default_kwargs=None, **kwargs):
**dict(default_kwargs, **kwargs))


def loads(s, _loads=json.loads, decode_bytes=True):
def object_hook(o: dict):
    """Hook function to perform custom deserialization.

    A dict whose keys are exactly ``__type__`` and ``__value__`` is treated
    as a tagged value: the decoder registered in ``_decoders`` under the
    ``__type__`` marker is applied to ``__value__`` and its result returned.
    Any other dict is returned unchanged.

    Raises:
        ValueError: the dict is a tagged value but no decoder is registered
            for its marker.  The marker and the offending dict are included
            in the exception arguments.
    """
    if o.keys() != {"__type__", "__value__"}:
        # Ordinary JSON object: nothing to decode.
        return o

    marker = o["__type__"]
    decoder = _decoders.get(marker)
    if decoder:
        return decoder(o["__value__"])
    # Report the unsupported marker itself (not the ``type`` builtin) so the
    # error identifies which tag could not be decoded.
    raise ValueError("Unsupported type", marker, o)


def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook):
"""Deserialize json from string."""
# None of the json implementations supports decoding from
# a buffer/memoryview, or even reading from a stream
Expand All @@ -78,14 +93,51 @@ def loads(s, _loads=json.loads, decode_bytes=True):
# over. Note that pickle does support buffer/memoryview
# </rant>
if isinstance(s, memoryview):
s = s.tobytes().decode('utf-8')
s = s.tobytes().decode("utf-8")
elif isinstance(s, bytearray):
s = s.decode('utf-8')
s = s.decode("utf-8")
elif decode_bytes and isinstance(s, bytes):
s = s.decode('utf-8')

try:
return _loads(s)
except _DecodeError:
# catch "Unpaired high surrogate" error
return stdjson.loads(s)
s = s.decode("utf-8")

return _loads(s, object_hook=object_hook)


# Loose signature shared by encoder and decoder callables:
# a single argument in, a single value out.
DecoderT = EncoderT = Callable[[Any], Any]
# T is the python type handed to ``register_type``; EncodedT is whatever its
# encoder returns and its decoder accepts (see ``register_type`` annotations).
T = TypeVar("T")
EncodedT = TypeVar("EncodedT")


def register_type(
    t: type[T],
    marker: str,
    encoder: Callable[[T], EncodedT],
    decoder: Callable[[EncodedT], T],
):
    """Register an encoder/decoder pair for a native python type.

    Arguments:
        t: The python type to support.
        marker: String tag stored alongside ``encoder`` under ``t`` and
            used as the lookup key for ``decoder``.
        encoder: Callable turning an instance of ``t`` into its
            encoded form.
        decoder: Callable turning the encoded form back into an
            instance of ``t``.

    An earlier registration for the same ``t`` or ``marker``
    is overwritten.
    """
    encoder_entry = (marker, encoder)
    _encoders[t] = encoder_entry
    _decoders[marker] = decoder


# Registries filled in by ``register_type``:
# ``_encoders`` maps a python type to its ``(marker, encoder)`` pair,
# ``_decoders`` maps a marker string to the callable rebuilding the value.
_encoders: dict[type, tuple[str, EncoderT]] = {}
_decoders: dict[str, DecoderT] = {
    "bytes": lambda o: o.encode("utf-8"),
    "base64": lambda o: base64.b64decode(o.encode("utf-8")),
}

# NOTE: datetime should be registered before date,
# because datetime is also instance of date.
# The registered type must be the class ``datetime.datetime``: this module
# does ``import datetime``, so the bare name ``datetime`` is the module and
# no value could ever be an instance of it.
register_type(
    datetime.datetime,
    "datetime",
    datetime.datetime.isoformat,
    datetime.datetime.fromisoformat,
)
register_type(
    datetime.date,
    "date",
    lambda o: o.isoformat(),
    # Parse as a full datetime first, then drop the time part.
    lambda o: datetime.datetime.fromisoformat(o).date(),
)
register_type(
    datetime.time,
    "time",
    lambda o: o.isoformat(),
    datetime.time.fromisoformat,
)
register_type(Decimal, "decimal", str, Decimal)
register_type(
    uuid.UUID,
    "uuid",
    # Encoded as a one-key dict so it can be splatted back into ``uuid.UUID``.
    lambda o: {"hex": o.hex},
    lambda o: uuid.UUID(**o),
)
Comment on lines +127 to +143
Copy link

@LincolnPuzeyCC LincolnPuzeyCC Oct 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given this is our fork of kombu, you could remove this part, since these registrations are overwritten by kraken anyway — and they contain the behavior we don't want: kombu decoding into typed objects rather than strings.

2 changes: 1 addition & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pytz>dev
pytest~=7.0.1
pytest-sugar
Pyro4
backports.zoneinfo>=0.2.1; python_version < '3.9'
8 changes: 6 additions & 2 deletions t/unit/utils/test_json.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from collections import namedtuple
from datetime import datetime
from decimal import Decimal
import sys
from unittest.mock import MagicMock, Mock
from uuid import uuid4

import pytest
import pytz

from kombu.utils.encoding import str_to_bytes
from kombu.utils.json import _DecodeError, dumps, loads

if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
else:
from backports.zoneinfo import ZoneInfo

class Custom:

Expand All @@ -24,7 +28,7 @@ class test_JSONEncoder:

def test_datetime(self):
now = datetime.utcnow()
now_utc = now.replace(tzinfo=pytz.utc)
now_utc = now.replace(tzinfo=ZoneInfo("UTC"))
stripped = datetime(*now.timetuple()[:3])
serialized = loads(dumps({
'datetime': now,
Expand Down