Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow unserializable #408

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions chatsky/__rebuild_pydantic_models__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from chatsky.slots.slots import SlotManager
from chatsky.core.context import FrameworkData, ServiceState
from chatsky.core.service import PipelineComponent
from chatsky.utils.devel import PydanticValue

PipelineComponent.model_rebuild()
Pipeline.model_rebuild()
Expand All @@ -15,3 +16,4 @@
ExtraHandlerRuntimeInfo.model_rebuild()
FrameworkData.model_rebuild()
ServiceState.model_rebuild()
PydanticValue.update_forward_refs()
7 changes: 4 additions & 3 deletions chatsky/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
import logging
import asyncio
from uuid import UUID, uuid4
from typing import Any, Optional, Union, Dict, TYPE_CHECKING
from typing import Optional, Union, Dict, TYPE_CHECKING

from pydantic import BaseModel, Field

from chatsky.utils.devel import PydanticValue
from chatsky.core.message import Message, MessageInitTypes
from chatsky.slots.slots import SlotManager
from chatsky.core.node_label import AbsoluteNodeLabel, AbsoluteNodeLabelInitTypes
Expand Down Expand Up @@ -87,7 +88,7 @@ class FrameworkData(BaseModel, arbitrary_types_allowed=True):
Instance of the pipeline that manages this context.
Can be used to obtain run configuration such as script or fallback label.
"""
stats: Dict[str, Any] = Field(default_factory=dict)
stats: Dict[str, PydanticValue] = Field(default_factory=dict)
"Enables complex stats collection across multiple turns."
slot_manager: SlotManager = Field(default_factory=SlotManager)
"Stores extracted slots."
Expand Down Expand Up @@ -133,7 +134,7 @@ class Context(BaseModel):
First response is stored at key ``1``.
IDs go up by ``1`` after that.
"""
misc: Dict[str, Any] = Field(default_factory=dict)
misc: Dict[str, PydanticValue] = Field(default_factory=dict)
"""
``misc`` stores any custom data. The framework doesn't use this dictionary,
so storage of any data won't reflect on the work of the internal Chatsky functions.
Expand Down
68 changes: 10 additions & 58 deletions chatsky/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,23 @@
"""

from __future__ import annotations
from typing import Literal, Optional, List, Union, Dict, Any, TYPE_CHECKING
from typing import Literal, Optional, List, Union, Dict, TYPE_CHECKING
from typing_extensions import TypeAlias, Annotated
from pathlib import Path
from urllib.request import urlopen
import uuid
import abc

from pydantic import Field, FilePath, HttpUrl, model_validator, field_validator, field_serializer
from pydantic import BaseModel, Field, FilePath, HttpUrl, model_validator
from pydantic_core import Url

from chatsky.utils.devel import (
json_pickle_validator,
json_pickle_serializer,
pickle_serializer,
pickle_validator,
JSONSerializableExtras,
)
from chatsky.utils.devel import PydanticValue

if TYPE_CHECKING:
from chatsky.messengers.common.interface import MessengerInterfaceWithAttachments


class DataModel(JSONSerializableExtras):
class DataModel(BaseModel, extra="allow"):
"""
This class is a Pydantic BaseModel that can have any type and number of extras.
"""
Expand Down Expand Up @@ -290,9 +284,9 @@ class level variables to store message information.
]
]
] = None
annotations: Optional[Dict[str, Any]] = None
misc: Optional[Dict[str, Any]] = None
original_message: Optional[Any] = None
annotations: Optional[Dict[str, PydanticValue]] = None
misc: Optional[Dict[str, PydanticValue]] = None
original_message: Optional[PydanticValue] = None

def __init__( # this allows initializing Message with string as positional argument
self,
Expand All @@ -318,9 +312,9 @@ def __init__( # this allows initializing Message with string as positional argu
]
]
] = None,
annotations: Optional[Dict[str, Any]] = None,
misc: Optional[Dict[str, Any]] = None,
original_message: Optional[Any] = None,
annotations: Optional[Dict[str, PydanticValue]] = None,
misc: Optional[Dict[str, PydanticValue]] = None,
original_message: Optional[PydanticValue] = None,
**kwargs,
):
super().__init__(
Expand All @@ -332,48 +326,6 @@ def __init__( # this allows initializing Message with string as positional argu
**kwargs,
)

@field_serializer("annotations", "misc", when_used="json")
def pickle_serialize_dicts(self, value):
"""
Serialize values that are not json-serializable via pickle.
Allows storing arbitrary data in misc/annotations when using context storages.
"""
if isinstance(value, dict):
return json_pickle_serializer(value)
return value

@field_validator("annotations", "misc", mode="before")
@classmethod
def pickle_validate_dicts(cls, value):
"""Restore values serialized with :py:meth:`pickle_serialize_dicts`."""
if isinstance(value, dict):
return json_pickle_validator(value)
return value

@field_serializer("original_message", when_used="json")
def pickle_serialize_original_message(self, value):
"""
Cast :py:attr:`original_message` to string via pickle.
Allows storing arbitrary data in this field when using context storages.
"""
if value is not None:
return pickle_serializer(value)
return value

@field_validator("original_message", mode="before")
@classmethod
def pickle_validate_original_message(cls, value):
"""
Restore :py:attr:`original_message` after being processed with
:py:meth:`pickle_serialize_original_message`.
"""
if value is not None:
return pickle_validator(value)
return value

def __str__(self) -> str:
return " ".join([f"{key}='{value}'" for key, value in self.model_dump(exclude_none=True).items()])

@model_validator(mode="before")
@classmethod
def validate_from_str(cls, data):
Expand Down
2 changes: 1 addition & 1 deletion chatsky/messengers/telegram/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ async def _on_event(self, update: Update, _: Any, create_message: Callable[[Upda
data_available = update.message is not None or update.callback_query is not None
if update.effective_chat is not None and data_available:
message = create_message(update)
message.original_message = update
message.original_message = update.to_dict(recursive=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also needs to apply to extra fields in Attachments.
Add a validator for Attachment and Message extras that converts an extra field via `to_dict` whenever its value is a `TelegramObject` instance.

AFAIK, if an extra field's value is a dictionary produced by `to_dict`, it should still work with the tg bot methods.

resp = await self._pipeline_runner(message, update.effective_chat.id)
if resp.last_response is not None:
await self.cast_message_to_telegram_and_send(
Expand Down
45 changes: 13 additions & 32 deletions chatsky/slots/slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@
import asyncio
import re
from abc import ABC, abstractmethod
from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Optional, Dict
from typing import Callable, Awaitable, TYPE_CHECKING, Union, Optional, Dict
from typing_extensions import TypeAlias, Annotated
import logging
from functools import reduce
from string import Formatter

from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator
from pydantic import BaseModel, model_validator, Field
from pydantic.dataclasses import dataclass

from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async
from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator
from chatsky.utils.devel import PydanticValue

if TYPE_CHECKING:
from chatsky.core import Context, Message
Expand Down Expand Up @@ -82,6 +83,7 @@ def recursive_setattr(obj, slot_name: SlotName, value):
setattr(parent_obj, slot, value)


@dataclass
class SlotNotExtracted(Exception):
"""This exception can be returned or raised by slot extractor if slot extraction is unsuccessful."""

Expand Down Expand Up @@ -117,29 +119,8 @@ class ExtractedValueSlot(ExtractedSlot):
"""Value extracted from :py:class:`~.ValueSlot`."""

is_slot_extracted: bool
extracted_value: Any
default_value: Any = None

@field_serializer("extracted_value", "default_value", when_used="json")
def pickle_serialize_values(self, value):
"""
Cast values to string via pickle.
Allows storing arbitrary data in these fields when using context storages.
"""
if value is not None:
return pickle_serializer(value)
return value

@field_validator("extracted_value", "default_value", mode="before")
@classmethod
def pickle_validate_values(cls, value):
"""
Restore values after being processed with
:py:meth:`pickle_serialize_values`.
"""
if value is not None:
return pickle_validator(value)
return value
extracted_value: PydanticValue
default_value: Optional[PydanticValue] = None

@property
def __slot_extracted__(self) -> bool:
Expand Down Expand Up @@ -219,10 +200,10 @@ class ValueSlot(BaseSlot, frozen=True):
Subclass it, if you want to declare your own slot type.
"""

default_value: Any = None
default_value: PydanticValue = None

@abstractmethod
async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]:
async def extract_value(self, ctx: Context) -> Union[PydanticValue, SlotNotExtracted]:
"""
Return value extracted from context.

Expand All @@ -246,14 +227,14 @@ async def get_value(self, ctx: Context) -> ExtractedValueSlot:
finally:
if not is_slot_extracted:
logger.debug(f"Slot {self.__class__.__name__!r} was not extracted: {extracted_value}")
return ExtractedValueSlot.model_construct(
return ExtractedValueSlot(
is_slot_extracted=is_slot_extracted,
extracted_value=extracted_value,
default_value=self.default_value,
)

def init_value(self) -> ExtractedValueSlot:
return ExtractedValueSlot.model_construct(
return ExtractedValueSlot(
is_slot_extracted=False,
extracted_value=SlotNotExtracted("Initial slot extraction."),
default_value=self.default_value,
Expand Down Expand Up @@ -328,9 +309,9 @@ class FunctionSlot(ValueSlot, frozen=True):
Uses a user-defined `func` to extract slot value from the :py:attr:`~.Context.last_request` Message.
"""

func: Callable[[Message], Union[Awaitable[Union[Any, SlotNotExtracted]], Any, SlotNotExtracted]]
func: Callable[[Message], Union[Awaitable[Union[PydanticValue, SlotNotExtracted]], PydanticValue, SlotNotExtracted]]

async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]:
async def extract_value(self, ctx: Context) -> Union[PydanticValue, SlotNotExtracted]:
return await wrap_sync_function_in_async(self.func, ctx.last_request)


Expand Down
8 changes: 1 addition & 7 deletions chatsky/utils/devel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@
parts of the framework.
"""

from .json_serialization import (
json_pickle_serializer,
json_pickle_validator,
pickle_serializer,
pickle_validator,
JSONSerializableExtras,
)
from .extra_field_helpers import grab_extra_fields
from .async_helpers import wrap_sync_function_in_async
from .serialization import PydanticValue
Loading