diff --git a/chatsky/core/script_function.py b/chatsky/core/script_function.py index 3ebeeb1d9..0dd4b54b6 100644 --- a/chatsky/core/script_function.py +++ b/chatsky/core/script_function.py @@ -201,6 +201,12 @@ class BaseProcessing(BaseScriptFunc, ABC): and :py:attr:`chatsky.core.script.Node.pre_response`. """ + start_condition: AnyCondition = Field(default=True, validate_default=True) + """ + :py:data:`~.AnyCondition` is checked before __call__; + __call__ is initiated only if start_condition returns ``True``. + """ + return_type: ClassVar[Union[type, Tuple[type, ...]]] = type(None) @abstractmethod @@ -211,7 +217,10 @@ async def wrapped_call(self, ctx: Context, *, info: str = "") -> Union[None, Exc return await super().wrapped_call(ctx, info=info) async def __call__(self, ctx: Context) -> None: - return await super().__call__(ctx) + if await self.start_condition.is_true(ctx): + return await super().__call__(ctx) + else: + return logger.debug(f"{self.__class__.__name__} not called: self.start_condition = {self.start_condition}") class BasePriority(BaseScriptFunc, ABC): diff --git a/tests/core/test_processing.py b/tests/core/test_processing.py index 0c3e7c509..0b7ba848a 100644 --- a/tests/core/test_processing.py +++ b/tests/core/test_processing.py @@ -1,4 +1,4 @@ -from chatsky import proc, Context, BaseResponse, MessageInitTypes, Message +from chatsky import proc, Context, BaseResponse, MessageInitTypes, Message, BaseProcessing from chatsky.core.script import Node @@ -22,3 +22,35 @@ async def modified_response(self, original_response: BaseResponse, ctx: Context) assert ctx.current_node.response.__class__.__name__ == "ModifiedResponse" assert await ctx.current_node.response(ctx) == Message(misc={"msg": Message("hi")}) + + +class TestConditionalResponce: + async def test_conditional_response(self): + ctx = Context() + ctx.framework_data.current_node = Node() + some_list = [] + + class SomeProcessing(BaseProcessing): + async def call(self, ctx: Context): + some_list.append("") + + await SomeProcessing()(ctx) + assert some_list == [""] + + await SomeProcessing().wrapped_call(ctx) + assert some_list == ["", ""] + + async def test_conditional_processing_false_condition(self): + ctx = Context() + ctx.framework_data.current_node = Node() + some_list = [] + + class SomeProcessing(BaseProcessing): + async def call(self, ctx: Context): + some_list.append("") + + await SomeProcessing(start_condition=False)(ctx) + assert some_list == [] + + await SomeProcessing(start_condition=False).wrapped_call(ctx) + assert some_list == [] diff --git a/tests/core/test_script_function.py b/tests/core/test_script_function.py index 0d23126a1..c099b86b6 100644 --- a/tests/core/test_script_function.py +++ b/tests/core/test_script_function.py @@ -47,8 +47,8 @@ async def call(self, ctx): raise RuntimeError() assert isinstance(await MyProc().wrapped_call(None), RuntimeError) - assert len(log_list) == 1 - assert log_list[0].levelname == "ERROR" + assert len(log_list) == 2 + assert log_list[1].levelname == "ERROR" async def test_base_exception_not_handled(self): class SpecialException(BaseException):