Skip to content

Commit

Permalink
🔒 fix injection problem for MsgTemplate
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Jan 1, 2024
1 parent 84f8169 commit e81ca45
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 5 deletions.
49 changes: 47 additions & 2 deletions src/nonebot_plugin_alconna/uniseg/template.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import functools
import _string # type: ignore
from string import Formatter
from typing_extensions import TypeAlias
from typing import (
Expand Down Expand Up @@ -102,6 +103,43 @@ def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]):
# f"not all arguments converted during string formatting: " f"{set(kwargs) - set(keys)}"
# )

def _getattr(self, route: str, obj: Any):
res = obj
parts = re.split(r"\.|(\[.+\])|(\(.*\))", route)
for part in parts[1:]:
if not part:
continue
if part.startswith("_"):
raise ValueError(route)
if part.startswith("[") and part.endswith("]"):
item = part[1:-1]
if item[0] in ("'", '"') and item[-1] in ("'", '"'):
res = res[item[1:-1]]
elif ":" in item:
res = res[slice(*map(lambda x: int(x) if x else None, item.split(':')))]
else:
res = res[int(item)]
elif part.startswith("(") and part.endswith(")"):
item = part[1:-1]
if not item:
res = res()
else:
_parts = item.split(",")
_args = []
_kwargs = {}
for part in _parts:
part = part.strip()
if re.match(".+=.+", part):
k, v = part.split("=")
_kwargs[k] = v
else:
_args.append(part)
res = res(*_args, **_kwargs)
else:
res = getattr(res, part)
return res


def _vformat(
self,
format_string: str,
Expand All @@ -126,14 +164,14 @@ def _vformat(
for part in parts:
part = part.strip()
if part.startswith("$") and (key := part.split(".")[0]) in kwargs:
_args.append(eval(part[1:], {}, {key[1:]: kwargs[key]}))
_args.append(self._getattr(part[1:], kwargs[key]))
elif re.match(".+=.+", part):
k, v = part.split("=")
if v in kwargs:
_kwargs[k] = kwargs[v]
used_args.add(v)
elif v.startswith("$") and (key := v.split(".")[0]) in kwargs:
_kwargs[k] = eval(v[1:], {}, {key[1:]: kwargs[key]})
_kwargs[k] = self._getattr(v[1:], kwargs[key])
else:
_kwargs[k] = v
elif part in kwargs:
Expand Down Expand Up @@ -179,6 +217,13 @@ def format_field(self, value: Any, format_spec: str) -> Any:
formatter = _MAPPING[format_spec] # type: ignore
return super().format_field(value, format_spec) if formatter is None else formatter(value)

def get_field(self, field_name, args, kwargs):
first, rest = _string.formatter_field_name_split(field_name)

obj = self.get_value(first, args, kwargs)

return obj, first

def _add(self, a: Any, b: Any) -> Any:
try:
return a + b
Expand Down
7 changes: 4 additions & 3 deletions tests/test_uniseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ async def test_unimsg_template(app: App):

@matcher.handle()
async def handle():
await matcher.finish(UniMessage.template("{:Reply($message_id)}{:At(user, $event.get_user_id())}"))
await matcher.finish(UniMessage.template("{:Reply($message_id)}{:At(user, $event.get_user_id()[1:])}"))

async with app.test_matcher(matcher) as ctx:
adapter = get_adapter(Adapter)
bot = ctx.create_bot(base=Bot, adapter=adapter)
event = fake_group_message_event_v11(message=Message("test_unimsg_template"), user_id=123)
ctx.receive_event(bot, event)
ctx.should_call_send(event, MessageSegment.reply(1) + MessageSegment.at(123))
ctx.should_call_send(event, MessageSegment.reply(1) + MessageSegment.at(23))
ctx.should_finished(matcher)


Expand Down Expand Up @@ -108,4 +108,5 @@ async def handle(msg: MsgId):
"message": Message("hello!"),
},
)
await Target("456", platform=adapter.get_name()).send("hello!")
target = Target("456", platform=adapter.get_name())
await target.send("hello!")

0 comments on commit e81ca45

Please sign in to comment.