Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: 修复 BUG #101

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions amiyabot/builtin/message/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,23 @@ def __init__(self, result: bool, weight: Union[int, float] = 0, keypoint: Option
self.weight = weight
self.keypoint = keypoint

self.on_selected: Optional[Callable] = None

def __bool__(self):
return bool(self.result)

def __repr__(self):
return f'<Verify, {self.result}, {self.weight}>'

def set_attrs(self, *attrs: Any):
indexes = [
'result',
'weight',
'keypoint',
]
for index, value in zip(indexes, attrs):
setattr(self, index, value)


@dataclass
class File:
Expand Down
73 changes: 49 additions & 24 deletions amiyabot/factory/implemented.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re

from typing import List
from contextlib import contextmanager
from dataclasses import dataclass
from amiyabot.util import remove_prefix_once
from amiyabot.builtin.message import Message, MessageMatch, Verify, Equal
Expand All @@ -8,7 +10,7 @@

@dataclass
class MessageHandlerItemImpl(MessageHandlerItem):
def __check(self, data: Message, obj: KeywordsType) -> Verify:
def __check(self, result: Verify, data: Message, obj: KeywordsType):
methods = {
str: MessageMatch.check_str,
Equal: MessageMatch.check_equal,
Expand All @@ -18,33 +20,54 @@ def __check(self, data: Message, obj: KeywordsType) -> Verify:

if t in methods:
method = methods[t]
check = Verify(*method(data, obj, self.level))
check = result.set_attrs(*method(data, obj, self.level))
if check:
return check

elif t is list:
for item in obj:
check = self.__check(data, item)
check = self.__check(result, data, item)
if check:
return check

return Verify(False)
return result

@classmethod
def update_data(cls, data: Message, prefix_keywords: List[str]):
def func():
text, prefix = remove_prefix_once(data.text, prefix_keywords)
if prefix:
data.text_prefix = prefix
data.set_text(text, set_original=False)

return func

@classmethod
@contextmanager
def restore_data(cls, result: Verify, data: Message):
if result.on_selected:
result.on_selected()
yield
data.text_prefix = ''
data.set_text(data.text_original, set_original=False)

async def verify(self, data: Message):
result = Verify(False)

# 检查是否支持私信
direct_only = self.direct_only or (self.group_config and self.group_config.direct_only)

if data.is_direct:
if not direct_only:
if self.allow_direct is None:
if not self.group_config or not self.group_config.allow_direct:
return Verify(False)
return result

if self.allow_direct is False:
return Verify(False)
return result
else:
if direct_only:
return Verify(False)
return result

# 检查是否包含前缀触发词或被 @
flag = False
Expand All @@ -62,13 +85,11 @@ async def verify(self, data: Message):

if not prefix_keywords:
flag = True

# 如果前缀校验通过,再次修正 Message 对象的属性值
text, prefix = remove_prefix_once(data.text, prefix_keywords)
if prefix:
flag = True
data.text_prefix = prefix
data.set_text(text, set_original=False)
else:
_, prefix = remove_prefix_once(data.text, prefix_keywords)
if prefix:
flag = True
result.on_selected = self.update_data(data, prefix_keywords)

# 若不通过以上检查,且关键字不为全等句式(Equal)
# 则允许当关键字为列表时,筛选列表内的全等句式继续执行校验,否则校验不通过
Expand All @@ -77,23 +98,27 @@ async def verify(self, data: Message):
if equal_filter:
self.keywords = equal_filter
else:
return Verify(False)
return result

# 执行自定义校验并修正其返回值
if self.custom_verify:
result = await self.custom_verify(data)
with self.restore_data(result, data):
res = await self.custom_verify(data)

if isinstance(res, bool) or res is None:
result.result = bool(res)
result.weight = int(bool(res))

if isinstance(result, bool) or result is None:
result = result, int(bool(result)), None
elif isinstance(res, tuple):
contrast = bool(res[0]), int(bool(res[0])), None
res_len = len(res)
res = (res + contrast[res_len:])[:3]

elif isinstance(result, tuple):
contrast = bool(result[0]), int(bool(result[0])), None
result_len = len(result)
result = (result + contrast[result_len:])[:3]
result.set_attrs(*res)

return Verify(*result)
return result

return self.__check(data, self.keywords)
return self.__check(result, data, self.keywords)

async def action(self, data: Message):
return await self.function(data)
2 changes: 2 additions & 0 deletions amiyabot/handler/messageHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ async def choice_handlers(data: Message, handlers: List[MessageHandlerItem], wai

# 将 Verify 结果赋值给 Message
data.verify = selected[0]
if data.verify.on_selected:
data.verify.on_selected()

return selected[1]

Expand Down
Loading