From d8b87945215f10f2acd7007501382d66c93e58f6 Mon Sep 17 00:00:00 2001 From: pseusys Date: Sat, 16 Nov 2024 00:00:59 +0800 Subject: [PATCH 1/3] pickle-json mixture removed --- chatsky/core/context.py | 8 +- chatsky/core/message.py | 68 ++------- chatsky/messengers/telegram/abstract.py | 2 +- chatsky/slots/slots.py | 38 ++--- chatsky/utils/devel/json_serialization.py | 154 --------------------- tests/core/test_message.py | 25 ++-- tests/utils/test_serialization.py | 160 ---------------------- 7 files changed, 31 insertions(+), 424 deletions(-) delete mode 100644 chatsky/utils/devel/json_serialization.py delete mode 100644 tests/utils/test_serialization.py diff --git a/chatsky/core/context.py b/chatsky/core/context.py index f0c03d3da..f36e9c661 100644 --- a/chatsky/core/context.py +++ b/chatsky/core/context.py @@ -20,9 +20,9 @@ import logging import asyncio from uuid import UUID, uuid4 -from typing import Any, Optional, Union, Dict, TYPE_CHECKING +from typing import Optional, Union, Dict, TYPE_CHECKING -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, JsonValue from chatsky.core.message import Message, MessageInitTypes from chatsky.slots.slots import SlotManager @@ -87,7 +87,7 @@ class FrameworkData(BaseModel, arbitrary_types_allowed=True): Instance of the pipeline that manages this context. Can be used to obtain run configuration such as script or fallback label. """ - stats: Dict[str, Any] = Field(default_factory=dict) + stats: Dict[str, Union[BaseModel, JsonValue]] = Field(default_factory=dict) "Enables complex stats collection across multiple turns." slot_manager: SlotManager = Field(default_factory=SlotManager) "Stores extracted slots." @@ -133,7 +133,7 @@ class Context(BaseModel): First response is stored at key ``1``. IDs go up by ``1`` after that. """ - misc: Dict[str, Any] = Field(default_factory=dict) + misc: Dict[str, Union[BaseModel, JsonValue]] = Field(default_factory=dict) """ ``misc`` stores any custom data. The framework doesn't use this dictionary, so storage of any data won't reflect on the work of the internal Chatsky functions. diff --git a/chatsky/core/message.py b/chatsky/core/message.py index 24a0c7e73..74236058e 100644 --- a/chatsky/core/message.py +++ b/chatsky/core/message.py @@ -7,29 +7,21 @@ """ from __future__ import annotations -from typing import Literal, Optional, List, Union, Dict, Any, TYPE_CHECKING +from typing import Literal, Optional, List, Union, Dict, TYPE_CHECKING from typing_extensions import TypeAlias, Annotated from pathlib import Path from urllib.request import urlopen import uuid import abc -from pydantic import Field, FilePath, HttpUrl, model_validator, field_validator, field_serializer +from pydantic import BaseModel, Field, FilePath, HttpUrl, JsonValue, model_validator from pydantic_core import Url -from chatsky.utils.devel import ( - json_pickle_validator, - json_pickle_serializer, - pickle_serializer, - pickle_validator, - JSONSerializableExtras, -) - if TYPE_CHECKING: from chatsky.messengers.common.interface import MessengerInterfaceWithAttachments -class DataModel(JSONSerializableExtras): +class DataModel(BaseModel, extra="allow"): """ This class is a Pydantic BaseModel that can have any type and number of extras. """ @@ -290,9 +282,9 @@ class level variables to store message information. ] ] ] = None - annotations: Optional[Dict[str, Any]] = None - misc: Optional[Dict[str, Any]] = None - original_message: Optional[Any] = None + annotations: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None + misc: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None + original_message: Optional[Union[BaseModel, JsonValue]] = None def __init__( # this allows initializing Message with string as positional argument self, @@ -318,9 +310,9 @@ def __init__( # this allows initializing Message with string as positional argu ] ] ] = None, - annotations: Optional[Dict[str, Any]] = None, - misc: Optional[Dict[str, Any]] = None, - original_message: Optional[Any] = None, + annotations: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None, + misc: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None, + original_message: Optional[Union[BaseModel, JsonValue]] = None, **kwargs, ): super().__init__( @@ -332,48 +324,6 @@ def __init__( # this allows initializing Message with string as positional argu **kwargs, ) - @field_serializer("annotations", "misc", when_used="json") - def pickle_serialize_dicts(self, value): - """ - Serialize values that are not json-serializable via pickle. - Allows storing arbitrary data in misc/annotations when using context storages. - """ - if isinstance(value, dict): - return json_pickle_serializer(value) - return value - - @field_validator("annotations", "misc", mode="before") - @classmethod - def pickle_validate_dicts(cls, value): - """Restore values serialized with :py:meth:`pickle_serialize_dicts`.""" - if isinstance(value, dict): - return json_pickle_validator(value) - return value - - @field_serializer("original_message", when_used="json") - def pickle_serialize_original_message(self, value): - """ - Cast :py:attr:`original_message` to string via pickle. - Allows storing arbitrary data in this field when using context storages. - """ - if value is not None: - return pickle_serializer(value) - return value - - @field_validator("original_message", mode="before") - @classmethod - def pickle_validate_original_message(cls, value): - """ - Restore :py:attr:`original_message` after being processed with - :py:meth:`pickle_serialize_original_message`. - """ - if value is not None: - return pickle_validator(value) - return value - - def __str__(self) -> str: - return " ".join([f"{key}='{value}'" for key, value in self.model_dump(exclude_none=True).items()]) - @model_validator(mode="before") @classmethod def validate_from_str(cls, data): diff --git a/chatsky/messengers/telegram/abstract.py b/chatsky/messengers/telegram/abstract.py index 1d464a4a7..f4e829be8 100644 --- a/chatsky/messengers/telegram/abstract.py +++ b/chatsky/messengers/telegram/abstract.py @@ -627,7 +627,7 @@ async def _on_event(self, update: Update, _: Any, create_message: Callable[[Upda data_available = update.message is not None or update.callback_query is not None if update.effective_chat is not None and data_available: message = create_message(update) - message.original_message = update + message.original_message = update.to_dict(recursive=True) resp = await self._pipeline_runner(message, update.effective_chat.id) if resp.last_response is not None: await self.cast_message_to_telegram_and_send( diff --git a/chatsky/slots/slots.py b/chatsky/slots/slots.py index 3cadd9205..cc9e36bf1 100644 --- a/chatsky/slots/slots.py +++ b/chatsky/slots/slots.py @@ -9,16 +9,15 @@ import asyncio import re from abc import ABC, abstractmethod -from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Optional, Dict +from typing import Callable, Awaitable, TYPE_CHECKING, Union, Optional, Dict from typing_extensions import TypeAlias, Annotated import logging from functools import reduce from string import Formatter -from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator +from pydantic import BaseModel, JsonValue, model_validator, Field from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async -from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator if TYPE_CHECKING: from chatsky.core import Context, Message @@ -117,29 +116,8 @@ class ExtractedValueSlot(ExtractedSlot): """Value extracted from :py:class:`~.ValueSlot`.""" is_slot_extracted: bool - extracted_value: Any - default_value: Any = None - - @field_serializer("extracted_value", "default_value", when_used="json") - def pickle_serialize_values(self, value): - """ - Cast values to string via pickle. - Allows storing arbitrary data in these fields when using context storages. - """ - if value is not None: - return pickle_serializer(value) - return value - - @field_validator("extracted_value", "default_value", mode="before") - @classmethod - def pickle_validate_values(cls, value): - """ - Restore values after being processed with - :py:meth:`pickle_serialize_values`. - """ - if value is not None: - return pickle_validator(value) - return value + extracted_value: Union[BaseModel, JsonValue] + default_value: Optional[Union[BaseModel, JsonValue]] = None @property def __slot_extracted__(self) -> bool: @@ -219,10 +197,10 @@ class ValueSlot(BaseSlot, frozen=True): Subclass it, if you want to declare your own slot type. """ - default_value: Any = None + default_value: Union[BaseModel, JsonValue] = None @abstractmethod - async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: + async def extract_value(self, ctx: Context) -> Union[Union[BaseModel, JsonValue], SlotNotExtracted]: """ Return value extracted from context. @@ -328,9 +306,9 @@ class FunctionSlot(ValueSlot, frozen=True): Uses a user-defined `func` to extract slot value from the :py:attr:`~.Context.last_request` Message. """ - func: Callable[[Message], Union[Awaitable[Union[Any, SlotNotExtracted]], Any, SlotNotExtracted]] + func: Callable[[Message], Union[Awaitable[Union[Union[BaseModel, JsonValue], SlotNotExtracted]], Union[BaseModel, JsonValue], SlotNotExtracted]] - async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: + async def extract_value(self, ctx: Context) -> Union[Union[BaseModel, JsonValue], SlotNotExtracted]: return await wrap_sync_function_in_async(self.func, ctx.last_request) diff --git a/chatsky/utils/devel/json_serialization.py b/chatsky/utils/devel/json_serialization.py deleted file mode 100644 index 132e79f65..000000000 --- a/chatsky/utils/devel/json_serialization.py +++ /dev/null @@ -1,154 +0,0 @@ -""" -Serialization -------------- -Tools that provide JSON serialization via Pickle for unserializable objects. - -- :py:data:`~.PickleEncodedValue`: - A field annotated with this will be pickled/unpickled during JSON-serialization/validation. -- :py:data:`~.JSONSerializableDict`: - A dictionary field annotated with this will make all its items smart-serializable: - If an item is serializable -- nothing would change. - Otherwise -- it will be serialized via pickle. -- :py:class:`~.JSONSerializableExtras`: - A pydantic base class that makes its extra fields a `JSONSerializableDict`. -""" - -from base64 import decodebytes, encodebytes -from copy import deepcopy -from pickle import dumps, loads -from typing import Any, Dict, List, Union -from typing_extensions import TypeAlias -from pydantic import ( - JsonValue, - RootModel, - BaseModel, - model_validator, - model_serializer, -) -from pydantic_core import PydanticSerializationError - -_JSON_EXTRA_FIELDS_KEYS = "__pickled_extra_fields__" -""" -This key is used in :py:data:`~.JSONSerializableDict` to remember pickled items. -""" - -Serializable: TypeAlias = Dict[str, Union[JsonValue, List[Any], Dict[str, Any], Any]] -"""Type annotation for objects supported by :py:func:`~.json_pickle_serializer`.""" - - -class _WrapperModel(RootModel): - """ - Wrapper model for testing whether an object is serializable to JSON. - """ - - root: Any - - -def pickle_serializer(value: Any) -> str: - """ - Serializer function that serializes any pickle-serializable value into JSON-serializable. - Serializes value with pickle and encodes bytes as base64 string. - - :param value: Pickle-serializable object. - :return: String-encoded object. - """ - - return encodebytes(dumps(value)).decode() - - -def pickle_validator(value: str) -> Any: - """ - Validator function that validates base64 string encoded bytes as a pickle-serializable value. - Decodes base64 string and validates value with pickle. - - :param value: String-encoded string. - :return: Pickle-serializable object. - """ - - return loads(decodebytes(value.encode())) - - -def json_pickle_serializer(model: Serializable) -> Serializable: - """ - Serializer function that serializes a dictionary or Pydantic object to JSON. - For every object field, it checks whether the field is JSON serializable, - and if it's not, serializes it using pickle. - It also keeps track of pickle-serializable field names in a special list. - - :param model: Pydantic model object or a dictionary. - :original_serializer: Original serializer function for model. - :return: model with all the fields serialized to JSON. - """ - - extra_fields = list() - model_copy = deepcopy(model) - - for field_name, field_value in model_copy.items(): - try: - if isinstance(field_value, bytes): - raise PydanticSerializationError("") - else: - model_copy[field_name] = _WrapperModel(root=field_value).model_dump(mode="json") - except PydanticSerializationError: - model_copy[field_name] = pickle_serializer(field_value) - extra_fields += [field_name] - - if len(extra_fields) > 0: - model_copy[_JSON_EXTRA_FIELDS_KEYS] = extra_fields - return model_copy - - -def json_pickle_validator(model: Serializable) -> Serializable: - """ - Validator function that validates a JSON dictionary to a python dictionary. - For every object field, it checks if the field is pickle-serialized, - and if it is, validates it using pickle. - - :param model: Pydantic model object or a dictionary. - :return: model with all the fields serialized to JSON. - """ - - model_copy = deepcopy(model) - - if _JSON_EXTRA_FIELDS_KEYS in model.keys(): - for extra_key in model[_JSON_EXTRA_FIELDS_KEYS]: - extra_value = model[extra_key] - model_copy[extra_key] = pickle_validator(extra_value) - del model_copy[_JSON_EXTRA_FIELDS_KEYS] - - return model_copy - - -class JSONSerializableExtras(BaseModel, extra="allow"): - """ - This model makes extra fields pickle-serializable. - Do not use :py:data:`~._JSON_EXTRA_FIELDS_KEYS` as an extra field name. - """ - - def __init__(self, **kwargs): # supress unknown arg warnings - super().__init__(**kwargs) - - @model_validator(mode="after") - def extra_validator(self): - """ - Validate model along with the `extras` field: i.e. all the fields not listed in the model. - - :return: Validated model. - """ - self.__pydantic_extra__ = json_pickle_validator(self.__pydantic_extra__) - return self - - @model_serializer(mode="wrap", when_used="json") - def extra_serializer(self, original_serializer) -> Dict[str, Any]: - """ - Serialize model along with the `extras` field: i.e. all the fields not listed in the model. - - :param original_serializer: Function originally used for serialization by Pydantic. - :return: Serialized model. - """ - model_copy = self.model_copy(deep=True) - for extra_name in self.model_extra.keys(): - delattr(model_copy, extra_name) - model_dict = original_serializer(model_copy) - model_dict.update(json_pickle_serializer(self.model_extra)) - return model_dict diff --git a/tests/core/test_message.py b/tests/core/test_message.py index 6b8b7d15f..4b7e0ded5 100644 --- a/tests/core/test_message.py +++ b/tests/core/test_message.py @@ -6,7 +6,7 @@ from urllib.request import urlopen import pytest -from pydantic import ValidationError, HttpUrl, FilePath +from pydantic import BaseModel, ValidationError, HttpUrl, FilePath from chatsky.messengers.common.interface import MessengerInterfaceWithAttachments from chatsky.messengers.console import CLIMessengerInterface @@ -30,16 +30,9 @@ EXAMPLE_SOURCE = "https://github.com/deeppavlov/chatsky/wiki/example_attachments" -class UnserializableObject: - def __init__(self, number: int, string: bytes) -> None: - self.number = number - self.bytes = string - - def __eq__(self, value: object) -> bool: - if isinstance(value, UnserializableObject): - return self.number == value.number and self.bytes == value.bytes - else: - return False +class SampleOriginalMessage(BaseModel): + num: int + bts: bytes class ChatskyCLIMessengerInterface(CLIMessengerInterface, MessengerInterfaceWithAttachments): @@ -60,8 +53,8 @@ async def get_attachment_bytes(self, attachment: str) -> bytes: class TestMessage: @pytest.fixture - def random_original_message(self) -> UnserializableObject: - return UnserializableObject(randint(0, 256), urandom(32)) + def random_original_message(self) -> SampleOriginalMessage: + return SampleOriginalMessage(num=randint(0, 256), bts=urandom(32)) def clear_and_create_dir(self, dir: Path) -> Path: rmtree(dir, ignore_errors=True) @@ -90,12 +83,12 @@ def test_attachment_serialize(self, attachment: DataAttachment): validated = Message.model_validate_json(serialized) assert message == validated - def test_field_serializable(self, random_original_message: UnserializableObject): + def test_field_serializable(self, random_original_message: SampleOriginalMessage): message = Message(text="sample message") - message.misc = {"answer": 42, "unserializable": random_original_message} + message.misc = {"answer": 42, "original": random_original_message} message.original_message = random_original_message message.some_extra_field = random_original_message - message.other_extra_field = {"unserializable": random_original_message} + message.other_extra_field = {"original": random_original_message} serialized = message.model_dump_json() validated = Message.model_validate_json(serialized) assert message == validated diff --git a/tests/utils/test_serialization.py b/tests/utils/test_serialization.py deleted file mode 100644 index 7765f4d3d..000000000 --- a/tests/utils/test_serialization.py +++ /dev/null @@ -1,160 +0,0 @@ -from typing import Optional, Dict, Any - -import pytest -from pydantic import BaseModel, field_serializer, field_validator -from copy import deepcopy - -import chatsky.utils.devel.json_serialization as json_ser - - -class UnserializableClass: - def __init__(self): - self.exc = RuntimeError("exception") - - def __eq__(self, other): - if not isinstance(other, UnserializableClass): - return False - return type(self.exc) == type(other.exc) and self.exc.args == other.exc.args # noqa: E721 - - -class PydanticClass(BaseModel, arbitrary_types_allowed=True): - field: Optional[UnserializableClass] - - -class TestJSONPickleSerialization: - @pytest.fixture(scope="function") - def unserializable_obj(self): - return UnserializableClass() - - ######################### - # DICT-RELATED FIXTURES # - ######################### - - @pytest.fixture(scope="function") - def unserializable_dict(self, unserializable_obj): - return { - "bytes": b"123", - "non_pydantic_non_serializable": unserializable_obj, - "non_pydantic_serializable": "string", - "pydantic_non_serializable": PydanticClass(field=unserializable_obj), - "pydantic_serializable": PydanticClass(field=None), - } - - @pytest.fixture(scope="function") - def non_serializable_fields(self): - return ["bytes", "non_pydantic_non_serializable", "pydantic_non_serializable"] - - @pytest.fixture(scope="function") - def deserialized_dict(self, unserializable_obj): - return { - "bytes": b"123", - "non_pydantic_non_serializable": unserializable_obj, - "non_pydantic_serializable": "string", - "pydantic_non_serializable": PydanticClass(field=unserializable_obj), - "pydantic_serializable": {"field": None}, - } - - ######################### - ######################### - ######################### - - def test_pickle(self, unserializable_obj): - serialized = json_ser.pickle_serializer(unserializable_obj) - assert isinstance(serialized, str) - assert json_ser.pickle_validator(serialized) == unserializable_obj - - def test_json_pickle(self, unserializable_dict, non_serializable_fields, deserialized_dict): - dict_copy = deepcopy(unserializable_dict) - - serialized = json_ser.json_pickle_serializer(dict_copy) - - assert dict_copy == unserializable_dict, "Dict changed by serializer" - - assert serialized[json_ser._JSON_EXTRA_FIELDS_KEYS] == non_serializable_fields - assert all(isinstance(serialized[field], str) for field in non_serializable_fields) - assert serialized["non_pydantic_serializable"] == "string" - assert serialized["pydantic_serializable"] == {"field": None} - - deserialized = json_ser.json_pickle_validator(serialized) - assert deserialized == deserialized_dict - - def test_serializable_value(self, unserializable_obj): - class Class(BaseModel): - field: Optional[Any] = None - - @field_serializer("field", when_used="json") - def pickle_serialize_field(self, value): - if value is not None: - return json_ser.pickle_serializer(value) - return value - - @field_validator("field", mode="before") - @classmethod - def pickle_validate_field(cls, value): - if value is not None: - return json_ser.pickle_validator(value) - return value - - obj = Class() - obj.field = unserializable_obj - - obj_copy = obj.model_copy(deep=True) - - dump = obj_copy.model_dump(mode="json") - - assert obj == obj_copy, "Object changed by serializer" - - assert isinstance(dump["field"], str) - - reconstructed_obj = Class.model_validate(dump) - - assert reconstructed_obj.field == unserializable_obj - - def test_serializable_dict(self, unserializable_dict, non_serializable_fields, deserialized_dict): - class Class(BaseModel): - field: Optional[Dict[str, Any]] = None - - @field_serializer("field", when_used="json") - def pickle_serialize_dicts(self, value): - if isinstance(value, dict): - return json_ser.json_pickle_serializer(value) - return value - - @field_validator("field", mode="before") - @classmethod - def pickle_validate_dicts(cls, value): - if isinstance(value, dict): - return json_ser.json_pickle_validator(value) - return value - - obj = Class(field=unserializable_dict) - - obj_copy = obj.model_copy(deep=True) - - dump = obj_copy.model_dump(mode="json") - - assert obj == obj_copy, "Object changed by serializer" - - assert dump["field"][json_ser._JSON_EXTRA_FIELDS_KEYS] == non_serializable_fields - - reconstructed_obj = Class.model_validate(dump) - - assert reconstructed_obj.field == deserialized_dict - - def test_serializable_extras(self, unserializable_dict, non_serializable_fields, deserialized_dict): - class Class(json_ser.JSONSerializableExtras): - pass - - obj = Class(**unserializable_dict) - - obj_copy = obj.model_copy(deep=True) - - dump = obj_copy.model_dump(mode="json") - - assert obj == obj_copy, "Object changed by serializer" - - assert dump[json_ser._JSON_EXTRA_FIELDS_KEYS] == non_serializable_fields - - reconstructed_obj = Class.model_validate(dump) - - assert reconstructed_obj.__pydantic_extra__ == deserialized_dict From 3182544c75afb5098404773ba9a038048cf6b09e Mon Sep 17 00:00:00 2001 From: pseusys Date: Mon, 18 Nov 2024 21:17:08 +0800 Subject: [PATCH 2/3] model_construct calls removed --- chatsky/slots/slots.py | 4 ++-- tests/slots/test_slot_manager.py | 20 +++++++------------- tests/slots/test_slot_types.py | 28 +++++++++++++--------------- 3 files changed, 22 insertions(+), 30 deletions(-) diff --git a/chatsky/slots/slots.py b/chatsky/slots/slots.py index cc9e36bf1..910978f29 100644 --- a/chatsky/slots/slots.py +++ b/chatsky/slots/slots.py @@ -224,14 +224,14 @@ async def get_value(self, ctx: Context) -> ExtractedValueSlot: finally: if not is_slot_extracted: logger.debug(f"Slot {self.__class__.__name__!r} was not extracted: {extracted_value}") - return ExtractedValueSlot.model_construct( + return ExtractedValueSlot( is_slot_extracted=is_slot_extracted, extracted_value=extracted_value, default_value=self.default_value, ) def init_value(self) -> ExtractedValueSlot: - return ExtractedValueSlot.model_construct( + return ExtractedValueSlot( is_slot_extracted=False, extracted_value=SlotNotExtracted("Initial slot extraction."), default_value=self.default_value, diff --git a/tests/slots/test_slot_manager.py b/tests/slots/test_slot_manager.py index 33885b360..0be674090 100644 --- a/tests/slots/test_slot_manager.py +++ b/tests/slots/test_slot_manager.py @@ -16,7 +16,7 @@ def faulty_func(_): raise SlotNotExtracted("Error.") -init_value_slot = ExtractedValueSlot.model_construct( +init_value_slot = ExtractedValueSlot( is_slot_extracted=False, extracted_value=SlotNotExtracted("Initial slot extraction."), default_value=None, @@ -34,16 +34,12 @@ def faulty_func(_): extracted_slot_values = { - "person.name": ExtractedValueSlot.model_construct( - is_slot_extracted=True, extracted_value="Bot", default_value=None - ), - "person.surname": ExtractedValueSlot.model_construct( + "person.name": ExtractedValueSlot(is_slot_extracted=True, extracted_value="Bot", default_value=None), + "person.surname": ExtractedValueSlot( is_slot_extracted=False, extracted_value=SlotNotExtracted("Error."), default_value=None ), - "person.email": ExtractedValueSlot.model_construct( - is_slot_extracted=True, extracted_value="bot@bot", default_value=None - ), - "msg_len": ExtractedValueSlot.model_construct(is_slot_extracted=True, extracted_value=29, default_value=None), + "person.email": ExtractedValueSlot(is_slot_extracted=True, extracted_value="bot@bot", default_value=None), + "msg_len": ExtractedValueSlot(is_slot_extracted=True, extracted_value=29, default_value=None), } @@ -54,7 +50,7 @@ def faulty_func(_): ) -unset_slot = ExtractedValueSlot.model_construct( +unset_slot = ExtractedValueSlot( is_slot_extracted=False, extracted_value=SlotNotExtracted("Slot manually unset."), default_value=None ) @@ -116,9 +112,7 @@ def extracted_slot_manager(): @pytest.fixture(scope="function") def fully_extracted_slot_manager(): slot_storage = full_slot_storage.model_copy(deep=True) - slot_storage.person.surname = ExtractedValueSlot.model_construct( - extracted_value="Bot", is_slot_extracted=True, default_value=None - ) + slot_storage.person.surname = ExtractedValueSlot(extracted_value="Bot", is_slot_extracted=True, default_value=None) return SlotManager(root_slot=root_slot, slot_storage=slot_storage) diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index a21cbd896..f95d7215f 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -18,12 +18,12 @@ ( Message(text="My name is Bot"), "(?<=name is ).+", - ExtractedValueSlot.model_construct(extracted_value="Bot", is_slot_extracted=True, default_value=None), + ExtractedValueSlot(extracted_value="Bot", is_slot_extracted=True, default_value=None), ), ( Message(text="I won't tell you my name"), "(?<=name is ).+$", - ExtractedValueSlot.model_construct( + ExtractedValueSlot( extracted_value=SlotNotExtracted( "Failed to match pattern {regexp!r} in {request_text!r}.".format( regexp="(?<=name is ).+$", request_text="I won't tell you my name" @@ -48,12 +48,12 @@ async def test_regexp(user_request, regexp, expected, context): ( Message(text="I am bot"), lambda msg: msg.text.split(" ")[2], - ExtractedValueSlot.model_construct(extracted_value="bot", is_slot_extracted=True, default_value=None), + ExtractedValueSlot(extracted_value="bot", is_slot_extracted=True, default_value=None), ), ( Message(text="My email is bot@bot"), lambda msg: [i for i in msg.text.split(" ") if "@" in i][0], - ExtractedValueSlot.model_construct(extracted_value="bot@bot", is_slot_extracted=True, default_value=None), + ExtractedValueSlot(extracted_value="bot@bot", is_slot_extracted=True, default_value=None), ), ], ) @@ -91,10 +91,10 @@ def func(msg: Message): email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), ), ExtractedGroupSlot( - name=ExtractedValueSlot.model_construct( + name=ExtractedValueSlot( is_slot_extracted=True, extracted_value="Bot", default_value=None ), - email=ExtractedValueSlot.model_construct( + email=ExtractedValueSlot( is_slot_extracted=True, extracted_value="bot@bot", default_value=None ), ), @@ -107,10 +107,10 @@ def func(msg: Message): email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), ), ExtractedGroupSlot( - name=ExtractedValueSlot.model_construct( + name=ExtractedValueSlot( is_slot_extracted=True, extracted_value="Bot", default_value=None ), - email=ExtractedValueSlot.model_construct( + email=ExtractedValueSlot( is_slot_extracted=False, extracted_value=SlotNotExtracted( "Failed to match pattern {regexp!r} in {request_text!r}.".format( @@ -139,20 +139,18 @@ def test_group_subslot_name_validation(forbidden_name): async def test_str_representation(): assert ( - str(ExtractedValueSlot.model_construct(is_slot_extracted=True, extracted_value="hello", default_value=None)) - == "hello" + str(ExtractedValueSlot(is_slot_extracted=True, extracted_value="hello", default_value=None)) == "hello" ) assert ( - str(ExtractedValueSlot.model_construct(is_slot_extracted=False, extracted_value=None, default_value="hello")) - == "hello" + str(ExtractedValueSlot(is_slot_extracted=False, extracted_value=None, default_value="hello")) == "hello" ) assert ( str( ExtractedGroupSlot( - first_name=ExtractedValueSlot.model_construct( + first_name=ExtractedValueSlot( is_slot_extracted=True, extracted_value="Tom", default_value="John" ), - last_name=ExtractedValueSlot.model_construct( + last_name=ExtractedValueSlot( is_slot_extracted=False, extracted_value=None, default_value="Smith" ), ) @@ -172,7 +170,7 @@ def __eq__(self, other): async def test_serialization(): - extracted_slot = ExtractedValueSlot.model_construct( + extracted_slot = ExtractedValueSlot( is_slot_extracted=True, extracted_value=UnserializableClass(), default_value=UnserializableClass() ) serialized = extracted_slot.model_dump_json() From 9b8615c69e300d9de1b1a8262f49bff697e1d534 Mon Sep 17 00:00:00 2001 From: pseusys Date: Mon, 18 Nov 2024 23:52:41 +0800 Subject: [PATCH 3/3] attempt to introduce pydantic value --- chatsky/__rebuild_pydantic_models__.py | 2 ++ chatsky/core/context.py | 7 ++++--- chatsky/core/message.py | 16 +++++++++------- chatsky/slots/slots.py | 17 ++++++++++------- chatsky/utils/devel/__init__.py | 8 +------- chatsky/utils/devel/serialization.py | 15 +++++++++++++++ 6 files changed, 41 insertions(+), 24 deletions(-) create mode 100644 chatsky/utils/devel/serialization.py diff --git a/chatsky/__rebuild_pydantic_models__.py b/chatsky/__rebuild_pydantic_models__.py index 1da4126a9..ad14bc5e0 100644 --- a/chatsky/__rebuild_pydantic_models__.py +++ b/chatsky/__rebuild_pydantic_models__.py @@ -7,6 +7,7 @@ from chatsky.slots.slots import SlotManager from chatsky.core.context import FrameworkData, ServiceState from chatsky.core.service import PipelineComponent +from chatsky.utils.devel import PydanticValue PipelineComponent.model_rebuild() Pipeline.model_rebuild() @@ -15,3 +16,4 @@ ExtraHandlerRuntimeInfo.model_rebuild() FrameworkData.model_rebuild() ServiceState.model_rebuild() +PydanticValue.update_forward_refs() diff --git a/chatsky/core/context.py b/chatsky/core/context.py index f36e9c661..e484c7717 100644 --- a/chatsky/core/context.py +++ b/chatsky/core/context.py @@ -22,8 +22,9 @@ from uuid import UUID, uuid4 from typing import Optional, Union, Dict, TYPE_CHECKING -from pydantic import BaseModel, Field, JsonValue +from pydantic import BaseModel, Field +from chatsky.utils.devel import PydanticValue from chatsky.core.message import Message, MessageInitTypes from chatsky.slots.slots import SlotManager from chatsky.core.node_label import AbsoluteNodeLabel, AbsoluteNodeLabelInitTypes @@ -87,7 +88,7 @@ class FrameworkData(BaseModel, arbitrary_types_allowed=True): Instance of the pipeline that manages this context. Can be used to obtain run configuration such as script or fallback label. """ - stats: Dict[str, Union[BaseModel, JsonValue]] = Field(default_factory=dict) + stats: Dict[str, PydanticValue] = Field(default_factory=dict) "Enables complex stats collection across multiple turns." slot_manager: SlotManager = Field(default_factory=SlotManager) "Stores extracted slots." @@ -133,7 +134,7 @@ class Context(BaseModel): First response is stored at key ``1``. IDs go up by ``1`` after that. """ - misc: Dict[str, Union[BaseModel, JsonValue]] = Field(default_factory=dict) + misc: Dict[str, PydanticValue] = Field(default_factory=dict) """ ``misc`` stores any custom data. The framework doesn't use this dictionary, so storage of any data won't reflect on the work of the internal Chatsky functions. diff --git a/chatsky/core/message.py b/chatsky/core/message.py index 74236058e..9c8a4f92b 100644 --- a/chatsky/core/message.py +++ b/chatsky/core/message.py @@ -14,9 +14,11 @@ import uuid import abc -from pydantic import BaseModel, Field, FilePath, HttpUrl, JsonValue, model_validator +from pydantic import BaseModel, Field, FilePath, HttpUrl, model_validator from pydantic_core import Url +from chatsky.utils.devel import PydanticValue + if TYPE_CHECKING: from chatsky.messengers.common.interface import MessengerInterfaceWithAttachments @@ -282,9 +284,9 @@ class level variables to store message information. ] ] ] = None - annotations: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None - misc: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None - original_message: Optional[Union[BaseModel, JsonValue]] = None + annotations: Optional[Dict[str, PydanticValue]] = None + misc: Optional[Dict[str, PydanticValue]] = None + original_message: Optional[PydanticValue] = None def __init__( # this allows initializing Message with string as positional argument self, @@ -310,9 +312,9 @@ def __init__( # this allows initializing Message with string as positional argu ] ] ] = None, - annotations: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None, - misc: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None, - original_message: Optional[Union[BaseModel, JsonValue]] = None, + annotations: Optional[Dict[str, PydanticValue]] = None, + misc: Optional[Dict[str, PydanticValue]] = None, + original_message: Optional[PydanticValue] = None, **kwargs, ): super().__init__( diff --git a/chatsky/slots/slots.py b/chatsky/slots/slots.py index 910978f29..4d5dfd3a3 100644 --- a/chatsky/slots/slots.py +++ b/chatsky/slots/slots.py @@ -15,9 +15,11 @@ from functools import reduce from string import Formatter -from pydantic import BaseModel, JsonValue, model_validator, Field +from pydantic import BaseModel, model_validator, Field +from pydantic.dataclasses import dataclass from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async +from chatsky.utils.devel import PydanticValue if TYPE_CHECKING: from chatsky.core import Context, Message @@ -81,6 +83,7 @@ def recursive_setattr(obj, slot_name: SlotName, value): setattr(parent_obj, slot, value) +@dataclass class SlotNotExtracted(Exception): """This exception can be returned or raised by slot extractor if slot extraction is unsuccessful.""" @@ -116,8 +119,8 @@ class ExtractedValueSlot(ExtractedSlot): """Value extracted from :py:class:`~.ValueSlot`.""" is_slot_extracted: bool - extracted_value: Union[BaseModel, JsonValue] - default_value: Optional[Union[BaseModel, JsonValue]] = None + extracted_value: PydanticValue + default_value: Optional[PydanticValue] = None @property def __slot_extracted__(self) -> bool: @@ -197,10 +200,10 @@ class ValueSlot(BaseSlot, frozen=True): Subclass it, if you want to declare your own slot type. """ - default_value: Union[BaseModel, JsonValue] = None + default_value: PydanticValue = None @abstractmethod - async def extract_value(self, ctx: Context) -> Union[Union[BaseModel, JsonValue], SlotNotExtracted]: + async def extract_value(self, ctx: Context) -> Union[PydanticValue, SlotNotExtracted]: """ Return value extracted from context. @@ -306,9 +309,9 @@ class FunctionSlot(ValueSlot, frozen=True): Uses a user-defined `func` to extract slot value from the :py:attr:`~.Context.last_request` Message. """ - func: Callable[[Message], Union[Awaitable[Union[Union[BaseModel, JsonValue], SlotNotExtracted]], Union[BaseModel, JsonValue], SlotNotExtracted]] + func: Callable[[Message], Union[Awaitable[Union[PydanticValue, SlotNotExtracted]], PydanticValue, SlotNotExtracted]] - async def extract_value(self, ctx: Context) -> Union[Union[BaseModel, JsonValue], SlotNotExtracted]: + async def extract_value(self, ctx: Context) -> Union[PydanticValue, SlotNotExtracted]: return await wrap_sync_function_in_async(self.func, ctx.last_request) diff --git a/chatsky/utils/devel/__init__.py b/chatsky/utils/devel/__init__.py index e7227f8c4..8b8a9390a 100644 --- a/chatsky/utils/devel/__init__.py +++ b/chatsky/utils/devel/__init__.py @@ -5,12 +5,6 @@ parts of the framework. """ -from .json_serialization import ( - json_pickle_serializer, - json_pickle_validator, - pickle_serializer, - pickle_validator, - JSONSerializableExtras, -) from .extra_field_helpers import grab_extra_fields from .async_helpers import wrap_sync_function_in_async +from .serialization import PydanticValue diff --git a/chatsky/utils/devel/serialization.py b/chatsky/utils/devel/serialization.py new file mode 100644 index 000000000..785630e8c --- /dev/null +++ b/chatsky/utils/devel/serialization.py @@ -0,0 +1,15 @@ +from typing import Dict, List, TypeAlias, Union + +from pydantic import BaseModel + + +PydanticValue: TypeAlias = Union[ + List["PydanticValue"], + Dict[str, "PydanticValue"], + BaseModel, + str, + bool, + int, + float, + None, +]