From 32bc2c314ab0ef2d58db5df22985be90bb372f59 Mon Sep 17 00:00:00 2001
From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
Date: Thu, 5 Dec 2024 20:55:24 +0800
Subject: [PATCH] =?UTF-8?q?:sparkles:=20Feature:=20=E5=AD=98=E5=82=A8=20ma?=
=?UTF-8?q?tcher=20=E5=8F=91=E9=80=81=20prompt=20=E7=9A=84=E7=BB=93?=
=?UTF-8?q?=E6=9E=9C=20(#3155)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
nonebot/consts.py | 4 ++
nonebot/internal/matcher/matcher.py | 40 +++++++++--
nonebot/internal/params.py | 43 ++++++++---
nonebot/params.py | 25 +++++++
tests/plugins/param/param_arg.py | 8 ++-
tests/plugins/param/param_matcher.py | 17 ++++-
tests/test_param.py | 53 +++++++++++++-
website/docs/advanced/dependency.mdx | 103 +++++++++++++++++++++++++++
8 files changed, 271 insertions(+), 22 deletions(-)
diff --git a/nonebot/consts.py b/nonebot/consts.py
index 701307d3..0cf4056f 100644
--- a/nonebot/consts.py
+++ b/nonebot/consts.py
@@ -22,6 +22,10 @@ REJECT_TARGET: Literal["_current_target"] = "_current_target"
"""当前 `reject` 目标存储 key"""
REJECT_CACHE_TARGET: Literal["_next_target"] = "_next_target"
"""下一个 `reject` 目标存储 key"""
+PAUSE_PROMPT_RESULT_KEY: Literal["_pause_result"] = "_pause_result"
+"""`pause` prompt 发送结果存储 key"""
+REJECT_PROMPT_RESULT_KEY: Literal["_reject_{key}_result"] = "_reject_{key}_result"
+"""`reject` prompt 发送结果存储 key"""
# used by Rule
PREFIX_KEY: Literal["_prefix"] = "_prefix"
diff --git a/nonebot/internal/matcher/matcher.py b/nonebot/internal/matcher/matcher.py
index 7f18effa..164b7b89 100644
--- a/nonebot/internal/matcher/matcher.py
+++ b/nonebot/internal/matcher/matcher.py
@@ -27,8 +27,10 @@ from exceptiongroup import BaseExceptionGroup, catch
from nonebot.consts import (
ARG_KEY,
LAST_RECEIVE_KEY,
+ PAUSE_PROMPT_RESULT_KEY,
RECEIVE_KEY,
REJECT_CACHE_TARGET,
+ REJECT_PROMPT_RESULT_KEY,
REJECT_TARGET,
)
from nonebot.dependencies import Dependent, Param
@@ -560,8 +562,8 @@ class Matcher(metaclass=MatcherMeta):
"""
bot = current_bot.get()
event = current_event.get()
- state = current_matcher.get().state
if isinstance(message, MessageTemplate):
+ state = current_matcher.get().state
_message = message.format(**state)
else:
_message = message
@@ -597,8 +599,15 @@ class Matcher(metaclass=MatcherMeta):
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
请参考对应 adapter 的 bot 对象 api
"""
+ try:
+ matcher = current_matcher.get()
+ except Exception:
+ matcher = None
+
if prompt is not None:
- await cls.send(prompt, **kwargs)
+ result = await cls.send(prompt, **kwargs)
+ if matcher is not None:
+ matcher.state[PAUSE_PROMPT_RESULT_KEY] = result
raise PausedException
@classmethod
@@ -615,8 +624,19 @@ class Matcher(metaclass=MatcherMeta):
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
请参考对应 adapter 的 bot 对象 api
"""
+ try:
+ matcher = current_matcher.get()
+ key = matcher.get_target()
+ except Exception:
+ matcher = None
+ key = None
+
+ key = REJECT_PROMPT_RESULT_KEY.format(key=key) if key is not None else None
+
if prompt is not None:
- await cls.send(prompt, **kwargs)
+ result = await cls.send(prompt, **kwargs)
+ if key is not None and matcher:
+ matcher.state[key] = result
raise RejectedException
@classmethod
@@ -636,9 +656,12 @@ class Matcher(metaclass=MatcherMeta):
请参考对应 adapter 的 bot 对象 api
"""
matcher = current_matcher.get()
- matcher.set_target(ARG_KEY.format(key=key))
+ arg_key = ARG_KEY.format(key=key)
+ matcher.set_target(arg_key)
+
if prompt is not None:
- await cls.send(prompt, **kwargs)
+ result = await cls.send(prompt, **kwargs)
+ matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=arg_key)] = result
raise RejectedException
@classmethod
@@ -658,9 +681,12 @@ class Matcher(metaclass=MatcherMeta):
请参考对应 adapter 的 bot 对象 api
"""
matcher = current_matcher.get()
- matcher.set_target(RECEIVE_KEY.format(id=id))
+ receive_key = RECEIVE_KEY.format(id=id)
+ matcher.set_target(receive_key)
+
if prompt is not None:
- await cls.send(prompt, **kwargs)
+ result = await cls.send(prompt, **kwargs)
+ matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=receive_key)] = result
raise RejectedException
@classmethod
diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py
index 89d11990..86f776d6 100644
--- a/nonebot/internal/params.py
+++ b/nonebot/internal/params.py
@@ -18,6 +18,7 @@ from exceptiongroup import BaseExceptionGroup, catch
from pydantic.fields import FieldInfo as PydanticFieldInfo
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
+from nonebot.consts import ARG_KEY, REJECT_PROMPT_RESULT_KEY
from nonebot.dependencies import Dependent, Param
from nonebot.dependencies.utils import check_field_type
from nonebot.exception import SkippedException
@@ -39,7 +40,7 @@ from nonebot.utils import (
)
if TYPE_CHECKING:
- from nonebot.adapters import Bot, Event
+ from nonebot.adapters import Bot, Event, Message
from nonebot.matcher import Matcher
@@ -522,10 +523,10 @@ class MatcherParam(Param):
class ArgInner:
def __init__(
- self, key: Optional[str], type: Literal["message", "str", "plaintext"]
+ self, key: Optional[str], type: Literal["message", "str", "plaintext", "prompt"]
) -> None:
self.key: Optional[str] = key
- self.type: Literal["message", "str", "plaintext"] = type
+ self.type: Literal["message", "str", "plaintext", "prompt"] = type
def __repr__(self) -> str:
return f"ArgInner(key={self.key!r}, type={self.type!r})"
@@ -546,6 +547,11 @@ def ArgPlainText(key: Optional[str] = None) -> str:
return ArgInner(key, "plaintext") # type: ignore
+def ArgPromptResult(key: Optional[str] = None) -> Any:
+ """`arg` prompt 发送结果"""
+ return ArgInner(key, "prompt")
+
+
class ArgParam(Param):
"""Arg 注入参数
@@ -559,7 +565,7 @@ class ArgParam(Param):
self,
*args,
key: str,
- type: Literal["message", "str", "plaintext"],
+ type: Literal["message", "str", "plaintext", "prompt"],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
@@ -584,15 +590,32 @@ class ArgParam(Param):
async def _solve( # pyright: ignore[reportIncompatibleMethodOverride]
self, matcher: "Matcher", **kwargs: Any
) -> Any:
- message = matcher.get_arg(self.key)
- if message is None:
- return message
if self.type == "message":
- return message
+ return self._solve_message(matcher)
elif self.type == "str":
- return str(message)
+ return self._solve_str(matcher)
+ elif self.type == "plaintext":
+ return self._solve_plaintext(matcher)
+ elif self.type == "prompt":
+ return self._solve_prompt(matcher)
else:
- return message.extract_plain_text()
+ raise ValueError(f"Unknown Arg type: {self.type}")
+
+ def _solve_message(self, matcher: "Matcher") -> Optional["Message"]:
+ return matcher.get_arg(self.key)
+
+ def _solve_str(self, matcher: "Matcher") -> Optional[str]:
+ message = matcher.get_arg(self.key)
+ return str(message) if message is not None else None
+
+ def _solve_plaintext(self, matcher: "Matcher") -> Optional[str]:
+ message = matcher.get_arg(self.key)
+ return message.extract_plain_text() if message is not None else None
+
+ def _solve_prompt(self, matcher: "Matcher") -> Optional[Any]:
+ return matcher.state.get(
+ REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key=self.key))
+ )
class ExceptionParam(Param):
diff --git a/nonebot/params.py b/nonebot/params.py
index a4400e7d..c2501051 100644
--- a/nonebot/params.py
+++ b/nonebot/params.py
@@ -19,9 +19,12 @@ from nonebot.consts import (
ENDSWITH_KEY,
FULLMATCH_KEY,
KEYWORD_KEY,
+ PAUSE_PROMPT_RESULT_KEY,
PREFIX_KEY,
RAW_CMD_KEY,
+ RECEIVE_KEY,
REGEX_MATCHED,
+ REJECT_PROMPT_RESULT_KEY,
SHELL_ARGS,
SHELL_ARGV,
STARTSWITH_KEY,
@@ -29,6 +32,7 @@ from nonebot.consts import (
from nonebot.internal.params import Arg as Arg
from nonebot.internal.params import ArgParam as ArgParam
from nonebot.internal.params import ArgPlainText as ArgPlainText
+from nonebot.internal.params import ArgPromptResult as ArgPromptResult
from nonebot.internal.params import ArgStr as ArgStr
from nonebot.internal.params import BotParam as BotParam
from nonebot.internal.params import DefaultParam as DefaultParam
@@ -252,6 +256,26 @@ def LastReceived(default: Any = None) -> Any:
return Depends(_last_received, use_cache=False)
+def ReceivePromptResult(id: Optional[str] = None) -> Any:
+ """`receive` prompt 发送结果"""
+
+ def _receive_prompt_result(matcher: "Matcher") -> Any:
+ return matcher.state.get(
+ REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id=id))
+ )
+
+ return Depends(_receive_prompt_result, use_cache=False)
+
+
+def PausePromptResult() -> Any:
+ """`pause` prompt 发送结果"""
+
+ def _pause_prompt_result(matcher: "Matcher") -> Any:
+ return matcher.state.get(PAUSE_PROMPT_RESULT_KEY)
+
+ return Depends(_pause_prompt_result, use_cache=False)
+
+
__autodoc__ = {
"Arg": True,
"ArgStr": True,
@@ -265,4 +289,5 @@ __autodoc__ = {
"DefaultParam": True,
"MatcherParam": True,
"ExceptionParam": True,
+ "ArgPromptResult": True,
}
diff --git a/tests/plugins/param/param_arg.py b/tests/plugins/param/param_arg.py
index 6bf64ded..c807228c 100644
--- a/tests/plugins/param/param_arg.py
+++ b/tests/plugins/param/param_arg.py
@@ -1,7 +1,7 @@
-from typing import Annotated
+from typing import Annotated, Any
from nonebot.adapters import Message
-from nonebot.params import Arg, ArgPlainText, ArgStr
+from nonebot.params import Arg, ArgPlainText, ArgPromptResult, ArgStr
async def arg(key: Message = Arg()) -> Message:
@@ -28,6 +28,10 @@ async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str:
return key
+async def annotated_arg_prompt_result(key: Annotated[Any, ArgPromptResult()]) -> Any:
+ return key
+
+
# test dependency priority
async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()):
return key
diff --git a/tests/plugins/param/param_matcher.py b/tests/plugins/param/param_matcher.py
index 6e8ec2fc..dd90b1f6 100644
--- a/tests/plugins/param/param_matcher.py
+++ b/tests/plugins/param/param_matcher.py
@@ -1,8 +1,13 @@
-from typing import TypeVar, Union
+from typing import Any, TypeVar, Union
from nonebot.adapters import Event
from nonebot.matcher import Matcher
-from nonebot.params import LastReceived, Received
+from nonebot.params import (
+ LastReceived,
+ PausePromptResult,
+ Received,
+ ReceivePromptResult,
+)
async def matcher(m: Matcher) -> Matcher:
@@ -59,3 +64,11 @@ async def receive(e: Event = Received("test")) -> Event:
async def last_receive(e: Event = LastReceived()) -> Event:
return e
+
+
+async def receive_prompt_result(result: Any = ReceivePromptResult("test")) -> Any:
+ return result
+
+
+async def pause_prompt_result(result: Any = PausePromptResult()) -> Any:
+ return result
diff --git a/tests/test_param.py b/tests/test_param.py
index c2001d56..bf561c9e 100644
--- a/tests/test_param.py
+++ b/tests/test_param.py
@@ -1,3 +1,4 @@
+from contextlib import suppress
import re
from exceptiongroup import BaseExceptionGroup
@@ -5,6 +6,7 @@ from nonebug import App
import pytest
from nonebot.consts import (
+ ARG_KEY,
CMD_ARG_KEY,
CMD_KEY,
CMD_START_KEY,
@@ -14,13 +16,14 @@ from nonebot.consts import (
KEYWORD_KEY,
PREFIX_KEY,
RAW_CMD_KEY,
+ RECEIVE_KEY,
REGEX_MATCHED,
SHELL_ARGS,
SHELL_ARGV,
STARTSWITH_KEY,
)
from nonebot.dependencies import Dependent
-from nonebot.exception import TypeMisMatch
+from nonebot.exception import PausedException, RejectedException, TypeMisMatch
from nonebot.matcher import Matcher
from nonebot.params import (
ArgParam,
@@ -469,8 +472,10 @@ async def test_matcher(app: App):
matcher,
not_legacy_matcher,
not_matcher,
+ pause_prompt_result,
postpone_matcher,
receive,
+ receive_prompt_result,
sub_matcher,
union_matcher,
)
@@ -538,12 +543,42 @@ async def test_matcher(app: App):
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(event_next)
+ fake_matcher.set_target(RECEIVE_KEY.format(id="test"), cache=False)
+
+ async with app.test_api() as ctx:
+ bot = ctx.create_bot()
+ ctx.should_call_send(event, "test", result=True, bot=bot)
+ with fake_matcher.ensure_context(bot, event):
+ with suppress(RejectedException):
+ await fake_matcher.reject("test")
+
+ async with app.test_dependent(
+ receive_prompt_result, allow_types=[MatcherParam, DependParam]
+ ) as ctx:
+ ctx.pass_params(matcher=fake_matcher)
+ ctx.should_return(True)
+
+ async with app.test_api() as ctx:
+ bot = ctx.create_bot()
+ ctx.should_call_send(event, "test", result=False, bot=bot)
+ with fake_matcher.ensure_context(bot, event):
+ fake_matcher.set_target("test")
+ with suppress(PausedException):
+ await fake_matcher.pause("test")
+
+ async with app.test_dependent(
+ pause_prompt_result, allow_types=[MatcherParam, DependParam]
+ ) as ctx:
+ ctx.pass_params(matcher=fake_matcher)
+ ctx.should_return(False)
+
@pytest.mark.anyio
async def test_arg(app: App):
from plugins.param.param_arg import (
annotated_arg,
annotated_arg_plain_text,
+ annotated_arg_prompt_result,
annotated_arg_str,
annotated_multi_arg,
annotated_prior_arg,
@@ -553,6 +588,7 @@ async def test_arg(app: App):
)
matcher = Matcher()
+ event = make_fake_event()()
message = FakeMessage("text")
matcher.set_arg("key", message)
@@ -582,6 +618,21 @@ async def test_arg(app: App):
ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text())
+ matcher.set_target(ARG_KEY.format(key="key"), cache=False)
+
+ async with app.test_api() as ctx:
+ bot = ctx.create_bot()
+ ctx.should_call_send(event, "test", result="arg", bot=bot)
+ with matcher.ensure_context(bot, event):
+ with suppress(RejectedException):
+ await matcher.reject("test")
+
+ async with app.test_dependent(
+ annotated_arg_prompt_result, allow_types=[ArgParam]
+ ) as ctx:
+ ctx.pass_params(matcher=matcher)
+ ctx.should_return("arg")
+
async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text())
diff --git a/website/docs/advanced/dependency.mdx b/website/docs/advanced/dependency.mdx
index 3efc35cd..e7946b39 100644
--- a/website/docs/advanced/dependency.mdx
+++ b/website/docs/advanced/dependency.mdx
@@ -1224,6 +1224,37 @@ async def _(foo: Event = LastReceived()): ...
+### ReceivePromptResult
+
+获取某次 `receive` 发送提示消息的结果。
+
+
+
+
+```python {6}
+from typing import Any, Annotated
+
+from nonebot.params import ReceivePromptResult
+
+@matcher.receive("id", prompt="prompt")
+async def _(result: Annotated[Any, ReceivePromptResult("id")]): ...
+```
+
+
+
+
+```python {6}
+from typing import Any
+
+from nonebot.params import ReceivePromptResult
+
+@matcher.receive("id", prompt="prompt")
+async def _(result: Any = ReceivePromptResult("id")): ...
+```
+
+
+
+
### Arg
获取某次 `got` 接收的参数。如果 `Arg` 参数留空,则使用函数的参数名作为要获取的参数。
@@ -1318,3 +1349,75 @@ async def _(foo: str = ArgPlainText("key")): ...
+
+### ArgPromptResult
+
+获取某次 `got` 发送提示消息的结果。如果 `Arg` 参数留空,则使用函数的参数名作为要获取的参数。
+
+
+
+
+```python {6,7}
+from typing import Any, Annotated
+
+from nonebot.params import ArgPromptResult
+
+@matcher.got("key", prompt="prompt")
+async def _(result: Annotated[Any, ArgPromptResult()]): ...
+async def _(result: Annotated[Any, ArgPromptResult("key")]): ...
+```
+
+
+
+
+```python {6,7}
+from typing import Any
+
+from nonebot.params import ArgPromptResult
+
+@matcher.got("key", prompt="prompt")
+async def _(result: Any = ArgPromptResult()): ...
+async def _(result: Any = ArgPromptResult("key")): ...
+```
+
+
+
+
+### PausePromptResult
+
+获取最近一次 `pause` 发送提示消息的结果。
+
+
+
+
+```python {6}
+from typing import Any, Annotated
+
+from nonebot.params import PausePromptResult
+
+@matcher.handle()
+async def _():
+ await matcher.pause(prompt="prompt")
+
+@matcher.handle()
+async def _(result: Annotated[Any, PausePromptResult()]): ...
+```
+
+
+
+
+```python {6}
+from typing import Any
+
+from nonebot.params import PausePromptResult
+
+@matcher.handle()
+async def _():
+ await matcher.pause(prompt="prompt")
+
+@matcher.handle()
+async def _(result: Any = PausePromptResult()): ...
+```
+
+
+