Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat new conditional processing #416

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion chatsky/core/script_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ class BaseProcessing(BaseScriptFunc, ABC):
and :py:attr:`chatsky.core.script.Node.pre_response`.
"""

start_condition: AnyCondition = Field(default=True, validate_default=True)
"""
:py:data:`~.AnyCondition` is checked before __call__;
__call__ is initiated only if start_condition returns ``True``.
"""

return_type: ClassVar[Union[type, Tuple[type, ...]]] = type(None)

@abstractmethod
Expand All @@ -211,7 +217,10 @@ async def wrapped_call(self, ctx: Context, *, info: str = "") -> Union[None, Exc
return await super().wrapped_call(ctx, info=info)

async def __call__(self, ctx: Context) -> None:
return await super().__call__(ctx)
if await self.start_condition.is_true(ctx):
return await super().__call__(ctx)
else:
return logger.debug(f"{self.__class__.__name__} not called: self.start_condition = {self.start_condition}")


class BasePriority(BaseScriptFunc, ABC):
Expand Down
34 changes: 33 additions & 1 deletion tests/core/test_processing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from chatsky import proc, Context, BaseResponse, MessageInitTypes, Message
from chatsky import proc, Context, BaseResponse, MessageInitTypes, Message, BaseProcessing
from chatsky.core.script import Node


Expand All @@ -22,3 +22,35 @@ async def modified_response(self, original_response: BaseResponse, ctx: Context)
assert ctx.current_node.response.__class__.__name__ == "ModifiedResponse"

assert await ctx.current_node.response(ctx) == Message(misc={"msg": Message("hi")})


class TestConditionalResponce:
async def test_conditional_response(self):
ctx = Context()
ctx.framework_data.current_node = Node()
some_list = []

class SomeProcessing(BaseProcessing):
async def call(self, ctx: Context):
some_list.append("")

await SomeProcessing()(ctx)
assert some_list == [""]

await SomeProcessing().wrapped_call(ctx)
assert some_list == ["", ""]

async def test_conditional_processing_false_condition(self):
ctx = Context()
ctx.framework_data.current_node = Node()
some_list = []

class SomeProcessing(BaseProcessing):
async def call(self, ctx: Context):
some_list.append("")

await SomeProcessing(start_condition=False)(ctx)
assert some_list == []

await SomeProcessing(start_condition=False).wrapped_call(ctx)
assert some_list == []
4 changes: 2 additions & 2 deletions tests/core/test_script_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ async def call(self, ctx):
raise RuntimeError()

assert isinstance(await MyProc().wrapped_call(None), RuntimeError)
assert len(log_list) == 1
assert log_list[0].levelname == "ERROR"
assert len(log_list) == 2
assert log_list[1].levelname == "ERROR"

async def test_base_exception_not_handled(self):
class SpecialException(BaseException):
Expand Down
Loading