mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-07 03:06:58 +08:00
✨ Feature: 存储 matcher 发送 prompt 的结果 (#3155)
This commit is contained in:
parent
ab8dea5a02
commit
32bc2c314a
@ -22,6 +22,10 @@ REJECT_TARGET: Literal["_current_target"] = "_current_target"
|
||||
"""当前 `reject` 目标存储 key"""
|
||||
REJECT_CACHE_TARGET: Literal["_next_target"] = "_next_target"
|
||||
"""下一个 `reject` 目标存储 key"""
|
||||
PAUSE_PROMPT_RESULT_KEY: Literal["_pause_result"] = "_pause_result"
|
||||
"""`pause` prompt 发送结果存储 key"""
|
||||
REJECT_PROMPT_RESULT_KEY: Literal["_reject_{key}_result"] = "_reject_{key}_result"
|
||||
"""`reject` prompt 发送结果存储 key"""
|
||||
|
||||
# used by Rule
|
||||
PREFIX_KEY: Literal["_prefix"] = "_prefix"
|
||||
|
@ -27,8 +27,10 @@ from exceptiongroup import BaseExceptionGroup, catch
|
||||
from nonebot.consts import (
|
||||
ARG_KEY,
|
||||
LAST_RECEIVE_KEY,
|
||||
PAUSE_PROMPT_RESULT_KEY,
|
||||
RECEIVE_KEY,
|
||||
REJECT_CACHE_TARGET,
|
||||
REJECT_PROMPT_RESULT_KEY,
|
||||
REJECT_TARGET,
|
||||
)
|
||||
from nonebot.dependencies import Dependent, Param
|
||||
@ -560,8 +562,8 @@ class Matcher(metaclass=MatcherMeta):
|
||||
"""
|
||||
bot = current_bot.get()
|
||||
event = current_event.get()
|
||||
state = current_matcher.get().state
|
||||
if isinstance(message, MessageTemplate):
|
||||
state = current_matcher.get().state
|
||||
_message = message.format(**state)
|
||||
else:
|
||||
_message = message
|
||||
@ -597,8 +599,15 @@ class Matcher(metaclass=MatcherMeta):
|
||||
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
|
||||
请参考对应 adapter 的 bot 对象 api
|
||||
"""
|
||||
try:
|
||||
matcher = current_matcher.get()
|
||||
except Exception:
|
||||
matcher = None
|
||||
|
||||
if prompt is not None:
|
||||
await cls.send(prompt, **kwargs)
|
||||
result = await cls.send(prompt, **kwargs)
|
||||
if matcher is not None:
|
||||
matcher.state[PAUSE_PROMPT_RESULT_KEY] = result
|
||||
raise PausedException
|
||||
|
||||
@classmethod
|
||||
@ -615,8 +624,19 @@ class Matcher(metaclass=MatcherMeta):
|
||||
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
|
||||
请参考对应 adapter 的 bot 对象 api
|
||||
"""
|
||||
try:
|
||||
matcher = current_matcher.get()
|
||||
key = matcher.get_target()
|
||||
except Exception:
|
||||
matcher = None
|
||||
key = None
|
||||
|
||||
key = REJECT_PROMPT_RESULT_KEY.format(key=key) if key is not None else None
|
||||
|
||||
if prompt is not None:
|
||||
await cls.send(prompt, **kwargs)
|
||||
result = await cls.send(prompt, **kwargs)
|
||||
if key is not None and matcher:
|
||||
matcher.state[key] = result
|
||||
raise RejectedException
|
||||
|
||||
@classmethod
|
||||
@ -636,9 +656,12 @@ class Matcher(metaclass=MatcherMeta):
|
||||
请参考对应 adapter 的 bot 对象 api
|
||||
"""
|
||||
matcher = current_matcher.get()
|
||||
matcher.set_target(ARG_KEY.format(key=key))
|
||||
arg_key = ARG_KEY.format(key=key)
|
||||
matcher.set_target(arg_key)
|
||||
|
||||
if prompt is not None:
|
||||
await cls.send(prompt, **kwargs)
|
||||
result = await cls.send(prompt, **kwargs)
|
||||
matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=arg_key)] = result
|
||||
raise RejectedException
|
||||
|
||||
@classmethod
|
||||
@ -658,9 +681,12 @@ class Matcher(metaclass=MatcherMeta):
|
||||
请参考对应 adapter 的 bot 对象 api
|
||||
"""
|
||||
matcher = current_matcher.get()
|
||||
matcher.set_target(RECEIVE_KEY.format(id=id))
|
||||
receive_key = RECEIVE_KEY.format(id=id)
|
||||
matcher.set_target(receive_key)
|
||||
|
||||
if prompt is not None:
|
||||
await cls.send(prompt, **kwargs)
|
||||
result = await cls.send(prompt, **kwargs)
|
||||
matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=receive_key)] = result
|
||||
raise RejectedException
|
||||
|
||||
@classmethod
|
||||
|
@ -18,6 +18,7 @@ from exceptiongroup import BaseExceptionGroup, catch
|
||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||
|
||||
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
|
||||
from nonebot.consts import ARG_KEY, REJECT_PROMPT_RESULT_KEY
|
||||
from nonebot.dependencies import Dependent, Param
|
||||
from nonebot.dependencies.utils import check_field_type
|
||||
from nonebot.exception import SkippedException
|
||||
@ -39,7 +40,7 @@ from nonebot.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.adapters import Bot, Event, Message
|
||||
from nonebot.matcher import Matcher
|
||||
|
||||
|
||||
@ -522,10 +523,10 @@ class MatcherParam(Param):
|
||||
|
||||
class ArgInner:
|
||||
def __init__(
|
||||
self, key: Optional[str], type: Literal["message", "str", "plaintext"]
|
||||
self, key: Optional[str], type: Literal["message", "str", "plaintext", "prompt"]
|
||||
) -> None:
|
||||
self.key: Optional[str] = key
|
||||
self.type: Literal["message", "str", "plaintext"] = type
|
||||
self.type: Literal["message", "str", "plaintext", "prompt"] = type
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ArgInner(key={self.key!r}, type={self.type!r})"
|
||||
@ -546,6 +547,11 @@ def ArgPlainText(key: Optional[str] = None) -> str:
|
||||
return ArgInner(key, "plaintext") # type: ignore
|
||||
|
||||
|
||||
def ArgPromptResult(key: Optional[str] = None) -> Any:
|
||||
"""`arg` prompt 发送结果"""
|
||||
return ArgInner(key, "prompt")
|
||||
|
||||
|
||||
class ArgParam(Param):
|
||||
"""Arg 注入参数
|
||||
|
||||
@ -559,7 +565,7 @@ class ArgParam(Param):
|
||||
self,
|
||||
*args,
|
||||
key: str,
|
||||
type: Literal["message", "str", "plaintext"],
|
||||
type: Literal["message", "str", "plaintext", "prompt"],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -584,15 +590,32 @@ class ArgParam(Param):
|
||||
async def _solve( # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
self, matcher: "Matcher", **kwargs: Any
|
||||
) -> Any:
|
||||
message = matcher.get_arg(self.key)
|
||||
if message is None:
|
||||
return message
|
||||
if self.type == "message":
|
||||
return message
|
||||
return self._solve_message(matcher)
|
||||
elif self.type == "str":
|
||||
return str(message)
|
||||
return self._solve_str(matcher)
|
||||
elif self.type == "plaintext":
|
||||
return self._solve_plaintext(matcher)
|
||||
elif self.type == "prompt":
|
||||
return self._solve_prompt(matcher)
|
||||
else:
|
||||
return message.extract_plain_text()
|
||||
raise ValueError(f"Unknown Arg type: {self.type}")
|
||||
|
||||
def _solve_message(self, matcher: "Matcher") -> Optional["Message"]:
|
||||
return matcher.get_arg(self.key)
|
||||
|
||||
def _solve_str(self, matcher: "Matcher") -> Optional[str]:
|
||||
message = matcher.get_arg(self.key)
|
||||
return str(message) if message is not None else None
|
||||
|
||||
def _solve_plaintext(self, matcher: "Matcher") -> Optional[str]:
|
||||
message = matcher.get_arg(self.key)
|
||||
return message.extract_plain_text() if message is not None else None
|
||||
|
||||
def _solve_prompt(self, matcher: "Matcher") -> Optional[Any]:
|
||||
return matcher.state.get(
|
||||
REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key=self.key))
|
||||
)
|
||||
|
||||
|
||||
class ExceptionParam(Param):
|
||||
|
@ -19,9 +19,12 @@ from nonebot.consts import (
|
||||
ENDSWITH_KEY,
|
||||
FULLMATCH_KEY,
|
||||
KEYWORD_KEY,
|
||||
PAUSE_PROMPT_RESULT_KEY,
|
||||
PREFIX_KEY,
|
||||
RAW_CMD_KEY,
|
||||
RECEIVE_KEY,
|
||||
REGEX_MATCHED,
|
||||
REJECT_PROMPT_RESULT_KEY,
|
||||
SHELL_ARGS,
|
||||
SHELL_ARGV,
|
||||
STARTSWITH_KEY,
|
||||
@ -29,6 +32,7 @@ from nonebot.consts import (
|
||||
from nonebot.internal.params import Arg as Arg
|
||||
from nonebot.internal.params import ArgParam as ArgParam
|
||||
from nonebot.internal.params import ArgPlainText as ArgPlainText
|
||||
from nonebot.internal.params import ArgPromptResult as ArgPromptResult
|
||||
from nonebot.internal.params import ArgStr as ArgStr
|
||||
from nonebot.internal.params import BotParam as BotParam
|
||||
from nonebot.internal.params import DefaultParam as DefaultParam
|
||||
@ -252,6 +256,26 @@ def LastReceived(default: Any = None) -> Any:
|
||||
return Depends(_last_received, use_cache=False)
|
||||
|
||||
|
||||
def ReceivePromptResult(id: Optional[str] = None) -> Any:
|
||||
"""`receive` prompt 发送结果"""
|
||||
|
||||
def _receive_prompt_result(matcher: "Matcher") -> Any:
|
||||
return matcher.state.get(
|
||||
REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id=id))
|
||||
)
|
||||
|
||||
return Depends(_receive_prompt_result, use_cache=False)
|
||||
|
||||
|
||||
def PausePromptResult() -> Any:
|
||||
"""`pause` prompt 发送结果"""
|
||||
|
||||
def _pause_prompt_result(matcher: "Matcher") -> Any:
|
||||
return matcher.state.get(PAUSE_PROMPT_RESULT_KEY)
|
||||
|
||||
return Depends(_pause_prompt_result, use_cache=False)
|
||||
|
||||
|
||||
__autodoc__ = {
|
||||
"Arg": True,
|
||||
"ArgStr": True,
|
||||
@ -265,4 +289,5 @@ __autodoc__ = {
|
||||
"DefaultParam": True,
|
||||
"MatcherParam": True,
|
||||
"ExceptionParam": True,
|
||||
"ArgPromptResult": True,
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
from nonebot.adapters import Message
|
||||
from nonebot.params import Arg, ArgPlainText, ArgStr
|
||||
from nonebot.params import Arg, ArgPlainText, ArgPromptResult, ArgStr
|
||||
|
||||
|
||||
async def arg(key: Message = Arg()) -> Message:
|
||||
@ -28,6 +28,10 @@ async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str:
|
||||
return key
|
||||
|
||||
|
||||
async def annotated_arg_prompt_result(key: Annotated[Any, ArgPromptResult()]) -> Any:
|
||||
return key
|
||||
|
||||
|
||||
# test dependency priority
|
||||
async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()):
|
||||
return key
|
||||
|
@ -1,8 +1,13 @@
|
||||
from typing import TypeVar, Union
|
||||
from typing import Any, TypeVar, Union
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.params import LastReceived, Received
|
||||
from nonebot.params import (
|
||||
LastReceived,
|
||||
PausePromptResult,
|
||||
Received,
|
||||
ReceivePromptResult,
|
||||
)
|
||||
|
||||
|
||||
async def matcher(m: Matcher) -> Matcher:
|
||||
@ -59,3 +64,11 @@ async def receive(e: Event = Received("test")) -> Event:
|
||||
|
||||
async def last_receive(e: Event = LastReceived()) -> Event:
|
||||
return e
|
||||
|
||||
|
||||
async def receive_prompt_result(result: Any = ReceivePromptResult("test")) -> Any:
|
||||
return result
|
||||
|
||||
|
||||
async def pause_prompt_result(result: Any = PausePromptResult()) -> Any:
|
||||
return result
|
||||
|
@ -1,3 +1,4 @@
|
||||
from contextlib import suppress
|
||||
import re
|
||||
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
@ -5,6 +6,7 @@ from nonebug import App
|
||||
import pytest
|
||||
|
||||
from nonebot.consts import (
|
||||
ARG_KEY,
|
||||
CMD_ARG_KEY,
|
||||
CMD_KEY,
|
||||
CMD_START_KEY,
|
||||
@ -14,13 +16,14 @@ from nonebot.consts import (
|
||||
KEYWORD_KEY,
|
||||
PREFIX_KEY,
|
||||
RAW_CMD_KEY,
|
||||
RECEIVE_KEY,
|
||||
REGEX_MATCHED,
|
||||
SHELL_ARGS,
|
||||
SHELL_ARGV,
|
||||
STARTSWITH_KEY,
|
||||
)
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import TypeMisMatch
|
||||
from nonebot.exception import PausedException, RejectedException, TypeMisMatch
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.params import (
|
||||
ArgParam,
|
||||
@ -469,8 +472,10 @@ async def test_matcher(app: App):
|
||||
matcher,
|
||||
not_legacy_matcher,
|
||||
not_matcher,
|
||||
pause_prompt_result,
|
||||
postpone_matcher,
|
||||
receive,
|
||||
receive_prompt_result,
|
||||
sub_matcher,
|
||||
union_matcher,
|
||||
)
|
||||
@ -538,12 +543,42 @@ async def test_matcher(app: App):
|
||||
ctx.pass_params(matcher=fake_matcher)
|
||||
ctx.should_return(event_next)
|
||||
|
||||
fake_matcher.set_target(RECEIVE_KEY.format(id="test"), cache=False)
|
||||
|
||||
async with app.test_api() as ctx:
|
||||
bot = ctx.create_bot()
|
||||
ctx.should_call_send(event, "test", result=True, bot=bot)
|
||||
with fake_matcher.ensure_context(bot, event):
|
||||
with suppress(RejectedException):
|
||||
await fake_matcher.reject("test")
|
||||
|
||||
async with app.test_dependent(
|
||||
receive_prompt_result, allow_types=[MatcherParam, DependParam]
|
||||
) as ctx:
|
||||
ctx.pass_params(matcher=fake_matcher)
|
||||
ctx.should_return(True)
|
||||
|
||||
async with app.test_api() as ctx:
|
||||
bot = ctx.create_bot()
|
||||
ctx.should_call_send(event, "test", result=False, bot=bot)
|
||||
with fake_matcher.ensure_context(bot, event):
|
||||
fake_matcher.set_target("test")
|
||||
with suppress(PausedException):
|
||||
await fake_matcher.pause("test")
|
||||
|
||||
async with app.test_dependent(
|
||||
pause_prompt_result, allow_types=[MatcherParam, DependParam]
|
||||
) as ctx:
|
||||
ctx.pass_params(matcher=fake_matcher)
|
||||
ctx.should_return(False)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_arg(app: App):
|
||||
from plugins.param.param_arg import (
|
||||
annotated_arg,
|
||||
annotated_arg_plain_text,
|
||||
annotated_arg_prompt_result,
|
||||
annotated_arg_str,
|
||||
annotated_multi_arg,
|
||||
annotated_prior_arg,
|
||||
@ -553,6 +588,7 @@ async def test_arg(app: App):
|
||||
)
|
||||
|
||||
matcher = Matcher()
|
||||
event = make_fake_event()()
|
||||
message = FakeMessage("text")
|
||||
matcher.set_arg("key", message)
|
||||
|
||||
@ -582,6 +618,21 @@ async def test_arg(app: App):
|
||||
ctx.pass_params(matcher=matcher)
|
||||
ctx.should_return(message.extract_plain_text())
|
||||
|
||||
matcher.set_target(ARG_KEY.format(key="key"), cache=False)
|
||||
|
||||
async with app.test_api() as ctx:
|
||||
bot = ctx.create_bot()
|
||||
ctx.should_call_send(event, "test", result="arg", bot=bot)
|
||||
with matcher.ensure_context(bot, event):
|
||||
with suppress(RejectedException):
|
||||
await matcher.reject("test")
|
||||
|
||||
async with app.test_dependent(
|
||||
annotated_arg_prompt_result, allow_types=[ArgParam]
|
||||
) as ctx:
|
||||
ctx.pass_params(matcher=matcher)
|
||||
ctx.should_return("arg")
|
||||
|
||||
async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx:
|
||||
ctx.pass_params(matcher=matcher)
|
||||
ctx.should_return(message.extract_plain_text())
|
||||
|
@ -1224,6 +1224,37 @@ async def _(foo: Event = LastReceived()): ...
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### ReceivePromptResult
|
||||
|
||||
获取某次 `receive` 发送提示消息的结果。
|
||||
|
||||
<Tabs groupId="annotated">
|
||||
<TabItem value="annotated" label="Use Annotated" default>
|
||||
|
||||
```python {6}
|
||||
from typing import Any, Annotated
|
||||
|
||||
from nonebot.params import ReceivePromptResult
|
||||
|
||||
@matcher.receive("id", prompt="prompt")
|
||||
async def _(result: Annotated[Any, ReceivePromptResult("id")]): ...
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="no-annotated" label="Without Annotated">
|
||||
|
||||
```python {6}
|
||||
from typing import Any
|
||||
|
||||
from nonebot.params import ReceivePromptResult
|
||||
|
||||
@matcher.receive("id", prompt="prompt")
|
||||
async def _(result: Any = ReceivePromptResult("id")): ...
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### Arg
|
||||
|
||||
获取某次 `got` 接收的参数。如果 `Arg` 参数留空,则使用函数的参数名作为要获取的参数。
|
||||
@ -1318,3 +1349,75 @@ async def _(foo: str = ArgPlainText("key")): ...
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### ArgPromptResult
|
||||
|
||||
获取某次 `got` 发送提示消息的结果。如果 `Arg` 参数留空,则使用函数的参数名作为要获取的参数。
|
||||
|
||||
<Tabs groupId="annotated">
|
||||
<TabItem value="annotated" label="Use Annotated" default>
|
||||
|
||||
```python {6,7}
|
||||
from typing import Any, Annotated
|
||||
|
||||
from nonebot.params import ArgPromptResult
|
||||
|
||||
@matcher.got("key", prompt="prompt")
|
||||
async def _(result: Annotated[Any, ArgPromptResult()]): ...
|
||||
async def _(result: Annotated[Any, ArgPromptResult("key")]): ...
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="no-annotated" label="Without Annotated">
|
||||
|
||||
```python {6,7}
|
||||
from typing import Any
|
||||
|
||||
from nonebot.params import ArgPromptResult
|
||||
|
||||
@matcher.got("key", prompt="prompt")
|
||||
async def _(result: Any = ArgPromptResult()): ...
|
||||
async def _(result: Any = ArgPromptResult("key")): ...
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### PausePromptResult
|
||||
|
||||
获取最近一次 `pause` 发送提示消息的结果。
|
||||
|
||||
<Tabs groupId="annotated">
|
||||
<TabItem value="annotated" label="Use Annotated" default>
|
||||
|
||||
```python {6}
|
||||
from typing import Any, Annotated
|
||||
|
||||
from nonebot.params import PausePromptResult
|
||||
|
||||
@matcher.handle()
|
||||
async def _():
|
||||
await matcher.pause(prompt="prompt")
|
||||
|
||||
@matcher.handle()
|
||||
async def _(result: Annotated[Any, PausePromptResult()]): ...
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="no-annotated" label="Without Annotated">
|
||||
|
||||
```python {6}
|
||||
from typing import Any
|
||||
|
||||
from nonebot.params import PausePromptResult
|
||||
|
||||
@matcher.handle()
|
||||
async def _():
|
||||
await matcher.pause(prompt="prompt")
|
||||
|
||||
@matcher.handle()
|
||||
async def _(result: Any = PausePromptResult()): ...
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
Loading…
Reference in New Issue
Block a user