diff --git a/amiyabot/builtin/message/structure.py b/amiyabot/builtin/message/structure.py index 863d9d3..cf67ff6 100644 --- a/amiyabot/builtin/message/structure.py +++ b/amiyabot/builtin/message/structure.py @@ -138,12 +138,23 @@ def __init__(self, result: bool, weight: Union[int, float] = 0, keypoint: Option self.weight = weight self.keypoint = keypoint + self.on_selected: Optional[Callable] = None + def __bool__(self): return bool(self.result) def __repr__(self): return f'' + def set_attrs(self, *attrs: Any): + indexes = [ + 'result', + 'weight', + 'keypoint', + ] + for index, value in zip(indexes, attrs): + setattr(self, index, value) + @dataclass class File: diff --git a/amiyabot/factory/implemented.py b/amiyabot/factory/implemented.py index 3ac3821..2369300 100644 --- a/amiyabot/factory/implemented.py +++ b/amiyabot/factory/implemented.py @@ -1,5 +1,7 @@ import re +from typing import List +from contextlib import contextmanager from dataclasses import dataclass from amiyabot.util import remove_prefix_once from amiyabot.builtin.message import Message, MessageMatch, Verify, Equal @@ -8,7 +10,7 @@ @dataclass class MessageHandlerItemImpl(MessageHandlerItem): - def __check(self, data: Message, obj: KeywordsType) -> Verify: + def __check(self, result: Verify, data: Message, obj: KeywordsType): methods = { str: MessageMatch.check_str, Equal: MessageMatch.check_equal, @@ -18,19 +20,40 @@ def __check(self, data: Message, obj: KeywordsType) -> Verify: if t in methods: method = methods[t] - check = Verify(*method(data, obj, self.level)) + check = result.set_attrs(*method(data, obj, self.level)) if check: return check elif t is list: for item in obj: - check = self.__check(data, item) + check = self.__check(result, data, item) if check: return check - return Verify(False) + return result + + @classmethod + def update_data(cls, data: Message, prefix_keywords: List[str]): + def func(): + text, prefix = remove_prefix_once(data.text, prefix_keywords) + if prefix: + data.text_prefix = prefix + data.set_text(text, set_original=False) + + return func + + @classmethod + @contextmanager + def restore_data(cls, result: Verify, data: Message): + if result.on_selected: + result.on_selected() + yield + data.text_prefix = '' + data.set_text(data.text_original, set_original=False) async def verify(self, data: Message): + result = Verify(False) + # 检查是否支持私信 direct_only = self.direct_only or (self.group_config and self.group_config.direct_only) @@ -38,13 +61,13 @@ async def verify(self, data: Message): if not direct_only: if self.allow_direct is None: if not self.group_config or not self.group_config.allow_direct: - return Verify(False) + return result if self.allow_direct is False: - return Verify(False) + return result else: if direct_only: - return Verify(False) + return result # 检查是否包含前缀触发词或被 @ flag = False @@ -62,13 +85,11 @@ async def verify(self, data: Message): if not prefix_keywords: flag = True - - # 如果前缀校验通过,再次修正 Message 对象的属性值 - text, prefix = remove_prefix_once(data.text, prefix_keywords) - if prefix: - flag = True - data.text_prefix = prefix - data.set_text(text, set_original=False) + else: + _, prefix = remove_prefix_once(data.text, prefix_keywords) + if prefix: + flag = True + result.on_selected = self.update_data(data, prefix_keywords) # 若不通过以上检查,且关键字不为全等句式(Equal) # 则允许当关键字为列表时,筛选列表内的全等句式继续执行校验,否则校验不通过 @@ -77,23 +98,27 @@ async def verify(self, data: Message): if equal_filter: self.keywords = equal_filter else: - return Verify(False) + return result # 执行自定义校验并修正其返回值 if self.custom_verify: - result = await self.custom_verify(data) + with self.restore_data(result, data): + res = await self.custom_verify(data) + + if isinstance(res, bool) or res is None: + result.result = bool(res) + result.weight = int(bool(res)) - if isinstance(result, bool) or result is None: - result = result, int(bool(result)), None + elif isinstance(res, tuple): + contrast = bool(res[0]), int(bool(res[0])), None + res_len = len(res) + res = (res + contrast[res_len:])[:3] - elif isinstance(result, tuple): - contrast = bool(result[0]), int(bool(result[0])), None - result_len = len(result) - result = (result + contrast[result_len:])[:3] + result.set_attrs(*res) - return Verify(*result) + return result - return self.__check(data, self.keywords) + return self.__check(result, data, self.keywords) async def action(self, data: Message): return await self.function(data) diff --git a/amiyabot/handler/messageHandler.py b/amiyabot/handler/messageHandler.py index e05cd78..3a26bdc 100644 --- a/amiyabot/handler/messageHandler.py +++ b/amiyabot/handler/messageHandler.py @@ -150,6 +150,8 @@ async def choice_handlers(data: Message, handlers: List[MessageHandlerItem], wai # 将 Verify 结果赋值给 Message data.verify = selected[0] + if data.verify.on_selected: + data.verify.on_selected() return selected[1]