From 39c880e3ef62e777a845efa10c61c97a63c2862e Mon Sep 17 00:00:00 2001 From: Azide Date: Wed, 27 Dec 2023 16:28:44 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20=E5=8E=BB=E9=99=A4MessageFactory?= =?UTF-8?q?=E5=92=8CMessageSegmentFactory=E7=9A=84=E6=B3=9B=E5=9E=8B?= =?UTF-8?q?=EF=BC=8C=E8=A1=A5=E5=85=85=E6=88=90=E5=91=98=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot_plugin_saa/abstract_factories.py | 309 +++++++++++++++++- nonebot_plugin_saa/adapters/dodo.py | 10 +- nonebot_plugin_saa/adapters/feishu.py | 16 +- nonebot_plugin_saa/adapters/kaiheila.py | 11 +- nonebot_plugin_saa/adapters/onebot_v11.py | 10 +- nonebot_plugin_saa/adapters/onebot_v12.py | 10 +- nonebot_plugin_saa/adapters/qq.py | 13 +- nonebot_plugin_saa/adapters/qqguild.py | 11 +- nonebot_plugin_saa/adapters/red.py | 18 +- nonebot_plugin_saa/adapters/telegram.py | 16 +- nonebot_plugin_saa/registries/message_id.py | 3 +- nonebot_plugin_saa/registries/receipt.py | 7 + .../types/common_message_segment.py | 18 +- tests/test_feishu.py | 6 +- tests/test_message.py | 233 ++++++++++++- 15 files changed, 626 insertions(+), 65 deletions(-) diff --git a/nonebot_plugin_saa/abstract_factories.py b/nonebot_plugin_saa/abstract_factories.py index ba732e22..7b7848d3 100644 --- a/nonebot_plugin_saa/abstract_factories.py +++ b/nonebot_plugin_saa/abstract_factories.py @@ -4,10 +4,13 @@ from warnings import warn from inspect import signature from typing_extensions import Self +from dataclasses import field, asdict, dataclass from typing import ( + Any, Dict, List, Type, + Tuple, Union, TypeVar, Callable, @@ -16,7 +19,9 @@ NoReturn, Optional, Awaitable, + SupportsIndex, cast, + overload, ) from nonebot.adapters import Bot, Event, Message, MessageSegment @@ -94,6 +99,7 @@ async def do_build_custom(builder: CustomBuildFunc, bot: Bot) -> MessageSegment: return cast(MessageSegment, res) +@dataclass class MessageSegmentFactory(ABC): _builders: ClassVar[ Dict[ @@ -105,8 +111,11 @@ class MessageSegmentFactory(ABC): ] ] - data: dict - _custom_builders: Dict[SupportedAdapters, CustomBuildFunc] + type: str + data: Dict[str, Any] = field(default_factory=dict) + _custom_builders: Dict[SupportedAdapters, CustomBuildFunc] = field( + init=False, default_factory=dict + ) def _register_custom_builder( self, @@ -131,8 +140,23 @@ def __init_subclass__(cls) -> None: cls._builders = {} return super().__init_subclass__() - def __eq__(self, other: Self) -> bool: - return self.data == other.data + def __str__(self) -> str: + kwstr = ",".join(f"{k}={v!r}" for k, v in self.data.items()) + return f"[SAA:{self.type}|{kwstr}]" + + def __repr__(self) -> str: + kwrepr = ", ".join(f"{k}={v!r}" for k, v in self.data.items()) + return f"{self.__class__.__name__}({kwrepr})" + + def __len__(self) -> int: + return len(self.data) + + def __eq__(self, other: object) -> bool: + if isinstance(other, MessageSegmentFactory): + return self.data == other.data + elif isinstance(other, str): + return self.data == {"text": other} + return False def overwrite( self, @@ -151,10 +175,16 @@ async def build(self, bot: Bot) -> MessageSegment: return await do_build(self, builder, bot) raise AdapterNotInstalled(adapter_name) - def __add__(self: TMSF, other: Union[str, TMSF, Iterable[TMSF]]): + def __add__( + self, + other: "str | MessageSegmentFactory | Iterable[str | MessageSegmentFactory]", + ) -> "MessageFactory": return MessageFactory(self) + other - def __radd__(self: TMSF, other: Union[str, TMSF, Iterable[TMSF]]): + def __radd__( + self, + other: "str | MessageSegmentFactory | Iterable[str | MessageSegmentFactory]", + ) -> "MessageFactory": return MessageFactory(other) + self async def send(self, *, at_sender=False, reply=False): @@ -206,13 +236,38 @@ async def reject_receive( await self.send(at_sender=at_sender, reply=reply, **kwargs) await matcher.reject_receive(key) + def copy(self) -> Self: + """深拷贝""" + return deepcopy(self) + + def _asdict(self): + _dict = asdict(self) + return {k: v for k, v in _dict.items() if not k.startswith("_")} + + def get(self, key: str, default: Any = None): + return asdict(self).get(key, default) + + def keys(self): + return self._asdict().keys() -class MessageFactory(List[TMSF]): - _text_factory: Callable[[str], TMSF] + def values(self): + return self._asdict().values() + + def items(self): + return self._asdict().items() + + def join( + self, iterable: "Iterable[MessageSegmentFactory | MessageFactory]" + ) -> "MessageFactory": + return MessageFactory(self).join(iterable) + + +class MessageFactory(List[MessageSegmentFactory]): + _text_factory: Callable[[str], MessageSegmentFactory] _message_registry: Dict[SupportedAdapters, Type[Message]] = {} @classmethod - def register_text_ms(cls, factory: Callable[[str], TMSF]): + def register_text_ms(cls, factory: Callable[[str], MessageSegmentFactory]): cls._text_factory = factory return factory @@ -241,7 +296,10 @@ async def _build(self, bot: Bot) -> Message: return message_type(ms) raise AdapterNotInstalled(adapter_name) - def __init__(self, message: Union[str, Iterable[TMSF], TMSF]): + def __init__( + self, + message: "str | MessageSegmentFactory | Iterable[str | MessageSegmentFactory] | None" = None, # noqa: E501 + ): super().__init__() if message is None: @@ -254,16 +312,25 @@ def __init__(self, message: Union[str, Iterable[TMSF], TMSF]): elif isinstance(message, Iterable): self.extend(message) - def __add__(self: TMF, other: Union[str, TMSF, Iterable[TMSF]]) -> TMF: + def __add__( + self: TMF, + other: "str | MessageSegmentFactory | Iterable[str | MessageSegmentFactory]", + ) -> TMF: result = self.copy() result += other return result - def __radd__(self: TMF, other: Union[str, TMSF, Iterable[TMSF]]) -> TMF: + def __radd__( + self: TMF, + other: "str | MessageSegmentFactory | Iterable[str | MessageSegmentFactory]", + ) -> TMF: result = self.__class__(other) return result + self - def __iadd__(self: TMF, other: Union[str, TMSF, Iterable[TMSF]]) -> TMF: + def __iadd__( + self: TMF, + other: "str | MessageSegmentFactory | Iterable[str | MessageSegmentFactory]", + ) -> TMF: if isinstance(other, str): self.append(self.get_text_factory()(other)) elif isinstance(other, MessageSegmentFactory): @@ -273,7 +340,7 @@ def __iadd__(self: TMF, other: Union[str, TMSF, Iterable[TMSF]]) -> TMF: return self - def append(self: TMF, obj: Union[str, TMSF]) -> TMF: + def append(self: TMF, obj: Union[str, MessageSegmentFactory]) -> TMF: if isinstance(obj, MessageSegmentFactory): super().append(obj) elif isinstance(obj, str): @@ -281,15 +348,39 @@ def append(self: TMF, obj: Union[str, TMSF]) -> TMF: return self - def extend(self: TMF, obj: Union[TMF, Iterable[TMSF]]) -> TMF: + def extend( + self: TMF, obj: Union[TMF, Iterable[Union[str, MessageSegmentFactory]]] + ) -> TMF: for message_segment_factory in obj: self.append(message_segment_factory) return self - def copy(self: TMF) -> TMF: + def copy(self) -> Self: return deepcopy(self) + def join(self, iterable: "Iterable[MessageSegmentFactory | Self]") -> Self: + """将多个消息连接并将自身作为分割 + + 参数: + iterable: 要连接的消息 + + 返回: + 连接后的消息 + """ + ret = self.__class__() + for index, msg in enumerate(iterable): + if index != 0: + ret.extend(self) + if isinstance(msg, MessageSegmentFactory): + ret.append(msg.copy()) + else: + ret.extend(msg.copy()) + return ret + + def __str__(self) -> str: + return "".join(str(ms_factory) for ms_factory in self) + async def send(self, *, at_sender=False, reply=False) -> "Receipt": "回复消息,仅能用在事件响应器中" try: @@ -365,6 +456,192 @@ async def _do_send( ) # pragma: no cover return await sender(bot, self, target, event, at_sender, reply) + @overload + def __getitem__(self, args: str) -> Self: + """获取仅包含指定消息段类型的消息 + + 参数: + args: 消息段类型 + + 返回: + 所有类型为 `args` 的消息段 + """ + + @overload + def __getitem__(self, args: Tuple[str, int]) -> MessageSegmentFactory: + """索引指定类型的消息段 + + 参数: + args: 消息段类型和索引 + + 返回: + 类型为 `args[0]` 的消息段第 `args[1]` 个 + """ + + @overload + def __getitem__(self, args: Tuple[str, slice]) -> Self: + """切片指定类型的消息段 + + 参数: + args: 消息段类型和切片 + + 返回: + 类型为 `args[0]` 的消息段切片 `args[1]` + """ + + @overload + def __getitem__(self, args: int) -> MessageSegmentFactory: + """索引消息段 + + 参数: + args: 索引 + + 返回: + 第 `args` 个消息段 + """ + + @overload + def __getitem__(self, args: slice) -> Self: + """切片消息段 + + 参数: + args: 切片 + + 返回: + 消息切片 `args` + """ + + def __getitem__( + self, + args: Union[ + str, + Tuple[str, int], + Tuple[str, slice], + int, + slice, + ], + ) -> Union[MessageSegmentFactory, Self]: + arg1, arg2 = args if isinstance(args, tuple) else (args, None) + if isinstance(arg1, int) and arg2 is None: + return super().__getitem__(arg1) + elif isinstance(arg1, slice) and arg2 is None: + return self.__class__(super().__getitem__(arg1)) + elif isinstance(arg1, str) and arg2 is None: + return self.__class__(seg for seg in self if seg.type == arg1) + elif isinstance(arg1, str) and isinstance(arg2, int): + return [seg for seg in self if seg.type == arg1][arg2] + elif isinstance(arg1, str) and isinstance(arg2, slice): + return self.__class__([seg for seg in self if seg.type == arg1][arg2]) + else: + raise ValueError("Incorrect arguments to slice") # pragma: no cover + + def __contains__(self, value: Union[MessageSegmentFactory, str]) -> bool: + """检查消息段是否存在 + + 参数: + value: 消息段或消息段类型 + 返回: + 消息内是否存在给定消息段或给定类型的消息段 + """ + if isinstance(value, str): + return bool(next((seg for seg in self if seg.type == value), None)) + return super().__contains__(value) + + def has(self, value: Union[MessageSegmentFactory, str]) -> bool: + """与 {ref}``__contains__` ` 相同""" + return value in self + + def index( + self, value: Union[MessageSegmentFactory, str], *args: SupportsIndex + ) -> int: + """索引消息段 + + 参数: + value: 消息段或者消息段类型 + arg: start 与 end + + 返回: + 索引 index + + 异常: + ValueError: 消息段不存在 + """ + if isinstance(value, str): + first_segment = next((seg for seg in self if seg.type == value), None) + if first_segment is None: + raise ValueError(f"Segment with type {value!r} is not in message") + return super().index(first_segment, *args) + return super().index(value, *args) + + def get(self, type_: str, count: Optional[int] = None): + """获取指定类型的消息段 + + 参数: + type_: 消息段类型 + count: 获取个数 + + 返回: + 构建的新消息 + """ + if count is None: + return self[type_] + + iterator, filtered = ( + seg for seg in self if seg.type == type_ + ), self.__class__() + for _ in range(count): + seg = next(iterator, None) + if seg is None: + break + filtered.append(seg) + return filtered + + def count(self, value: Union[MessageSegmentFactory, str]) -> int: + """计算指定消息段的个数 + + 参数: + value: 消息段或消息段类型 + + 返回: + 个数 + """ + return len(self[value]) if isinstance(value, str) else super().count(value) + + def only(self, value: Union[MessageSegmentFactory, str]) -> bool: + """检查消息中是否仅包含指定消息段 + + 参数: + value: 指定消息段或消息段类型 + + 返回: + 是否仅包含指定消息段 + """ + if isinstance(value, str): + return all(seg.type == value for seg in self) + return all(seg == value for seg in self) + + def include(self, *types: str) -> Self: + """过滤消息 + + 参数: + types: 包含的消息段类型 + + 返回: + 新构造的消息 + """ + return self.__class__(seg for seg in self if seg.type in types) + + def exclude(self, *types: str) -> Self: + """过滤消息 + + 参数: + types: 不包含的消息段类型 + + 返回: + 新构造的消息 + """ + return self.__class__(seg for seg in self if seg.type not in types) + AggregatedSender = Callable[ [Bot, List[MessageFactory], PlatformTarget, Optional[Event]], diff --git a/nonebot_plugin_saa/adapters/dodo.py b/nonebot_plugin_saa/adapters/dodo.py index 0f6d5f8c..7fdb0ea8 100644 --- a/nonebot_plugin_saa/adapters/dodo.py +++ b/nonebot_plugin_saa/adapters/dodo.py @@ -12,7 +12,6 @@ from ..utils import SupportedAdapters, SupportedPlatform from ..abstract_factories import ( MessageFactory, - MessageSegmentFactory, register_ms_adapter, assamble_message_factory, ) @@ -91,8 +90,8 @@ async def _image(image: Image, bot: BaseBot) -> MessageSegment: @register_dodo(Reply) def _reply(reply: Reply) -> MessageSegment: - assert isinstance(reply.data, DodoMessageId) - return MessageSegment.reference(reply.data.message_id) + assert isinstance(mid := reply.data["message_id"], DodoMessageId) + return MessageSegment.reference(mid.message_id) @register_dodo(Mention) def _mention(mention: Mention) -> MessageSegment: @@ -181,10 +180,13 @@ async def pin(self, is_cancel: bool = False): def raw(self) -> str: return self.message_id + def extract_message_id(self) -> DodoMessageId: + return DodoMessageId(message_id=self.message_id) + @register_sender(adapter) async def send( bot, - msg: MessageFactory[MessageSegmentFactory], + msg: MessageFactory, target, event, at_sender: bool, diff --git a/nonebot_plugin_saa/adapters/feishu.py b/nonebot_plugin_saa/adapters/feishu.py index f7bf417b..525bfd4b 100644 --- a/nonebot_plugin_saa/adapters/feishu.py +++ b/nonebot_plugin_saa/adapters/feishu.py @@ -10,7 +10,6 @@ from ..types import Text, Image, Reply, Mention from ..abstract_factories import ( MessageFactory, - MessageSegmentFactory, register_ms_adapter, assamble_message_factory, ) @@ -78,8 +77,8 @@ def _mention(m: Mention) -> MessageSegment: @register_feishu(Reply) def _reply(r: Reply) -> MessageSegment: - assert isinstance(r.data, FeishuMessageId) - return MessageSegment("reply", {"message_id": r.data.message_id}) + assert isinstance(mid := r.data["message_id"], FeishuMessageId) + return MessageSegment("reply", {"message_id": mid.message_id}) @register_target_extractor(PrivateMessageEvent) def _extract_private_msg_event(event: Event) -> TargetFeishuPrivate: @@ -105,6 +104,9 @@ async def revoke(self): def raw(self) -> Any: return self.data + def extract_message_id(self) -> FeishuMessageId: + return FeishuMessageId(message_id=self.message_id) + @register_message_id_getter(MessageEvent) def _(event: Event) -> FeishuMessageId: assert isinstance(event, MessageEvent) @@ -113,7 +115,7 @@ def _(event: Event) -> FeishuMessageId: @register_sender(adapter) async def send( bot, - msg: MessageFactory[MessageSegmentFactory], + msg: MessageFactory, target, event, at_sender: bool, @@ -142,8 +144,10 @@ async def send( message_to_send = Message() for message_segment_factory in full_msg: if isinstance(message_segment_factory, Reply): - assert isinstance(message_segment_factory.data, FeishuMessageId) - reply_to_message_id = message_segment_factory.data.message_id + assert isinstance( + mid := message_segment_factory.data["message_id"], FeishuMessageId + ) + reply_to_message_id = mid.message_id continue message_segment = await message_segment_factory.build(bot) diff --git a/nonebot_plugin_saa/adapters/kaiheila.py b/nonebot_plugin_saa/adapters/kaiheila.py index a5b0e5e1..ea84491c 100644 --- a/nonebot_plugin_saa/adapters/kaiheila.py +++ b/nonebot_plugin_saa/adapters/kaiheila.py @@ -9,7 +9,6 @@ from ..utils import SupportedAdapters, SupportedPlatform from ..abstract_factories import ( MessageFactory, - MessageSegmentFactory, register_ms_adapter, assamble_message_factory, ) @@ -84,8 +83,8 @@ def _mention(m: Mention) -> MessageSegment: @register_kaiheila(Reply) def _reply(r: Reply) -> MessageSegment: - assert isinstance(r.data, KaiheilaMessageId) - return MessageSegment.quote(r.data.message_id) + assert isinstance(mid := r.data["message_id"], KaiheilaMessageId) + return MessageSegment.quote(mid.message_id) @register_target_extractor(PrivateMessageEvent) def _extract_private_msg_event(event: Event) -> TargetKaiheilaPrivate: @@ -125,6 +124,10 @@ async def revoke(self): def raw(self) -> MessageCreateReturn: return self.data + def extract_message_id(self) -> MessageId: + assert self.data.msg_id + return KaiheilaMessageId(message_id=self.data.msg_id) + @register_message_id_getter(MessageEvent) def _(event: Event) -> KaiheilaMessageId: assert isinstance(event, MessageEvent) @@ -133,7 +136,7 @@ def _(event: Event) -> KaiheilaMessageId: @register_sender(SupportedAdapters.kaiheila) async def send( bot, - msg: MessageFactory[MessageSegmentFactory], + msg: MessageFactory, target, event, at_sender: bool, diff --git a/nonebot_plugin_saa/adapters/onebot_v11.py b/nonebot_plugin_saa/adapters/onebot_v11.py index 75478e6d..ec38b78b 100644 --- a/nonebot_plugin_saa/adapters/onebot_v11.py +++ b/nonebot_plugin_saa/adapters/onebot_v11.py @@ -8,7 +8,6 @@ from ..utils import SupportedAdapters, SupportedPlatform from ..abstract_factories import ( MessageFactory, - MessageSegmentFactory, AggregatedMessageFactory, register_ms_adapter, assamble_message_factory, @@ -71,8 +70,8 @@ async def _mention(m: Mention) -> MessageSegment: @register_onebot_v11(Reply) async def _reply(r: Reply) -> MessageSegment: - assert isinstance(r.data, OB11MessageId) - return MessageSegment.reply(r.data.message_id) + assert isinstance(mid := r.data["message_id"], OB11MessageId) + return MessageSegment.reply(mid.message_id) @register_target_extractor(PrivateMessageEvent) def _extract_private_msg_event(event: Event) -> TargetQQPrivate: @@ -170,10 +169,13 @@ async def revoke(self): def raw(self) -> Any: return self.message_id + def extract_message_id(self) -> OB11MessageId: + return OB11MessageId(message_id=self.message_id) + @register_sender(SupportedAdapters.onebot_v11) async def send( bot, - msg: MessageFactory[MessageSegmentFactory], + msg: MessageFactory, target, event, at_sender: bool, diff --git a/nonebot_plugin_saa/adapters/onebot_v12.py b/nonebot_plugin_saa/adapters/onebot_v12.py index 53281a18..0189bcc9 100644 --- a/nonebot_plugin_saa/adapters/onebot_v12.py +++ b/nonebot_plugin_saa/adapters/onebot_v12.py @@ -11,7 +11,6 @@ from ..utils import SupportedAdapters, SupportedPlatform from ..abstract_factories import ( MessageFactory, - MessageSegmentFactory, register_ms_adapter, assamble_message_factory, ) @@ -97,8 +96,8 @@ async def _mention(m: Mention) -> MessageSegment: @register_onebot_v12(Reply) async def _reply(r: Reply) -> MessageSegment: - assert isinstance(r.data, OB12MessageId) - return MessageSegment.reply(r.data.message_id) + assert isinstance(mid := r.data["message_id"], OB12MessageId) + return MessageSegment.reply(mid.message_id) @register_target_extractor(PrivateMessageEvent) def _extract_private_msg_event(event: Event) -> PlatformTarget: @@ -254,10 +253,13 @@ async def revoke(self): def raw(self): return self.message_id + def extract_message_id(self) -> OB12MessageId: + return OB12MessageId(message_id=self.message_id) + @register_sender(SupportedAdapters.onebot_v12) async def send( bot, - msg: MessageFactory[MessageSegmentFactory], + msg: MessageFactory, target, event, at_sender: bool, diff --git a/nonebot_plugin_saa/adapters/qq.py b/nonebot_plugin_saa/adapters/qq.py index 1ec417ba..5094e424 100644 --- a/nonebot_plugin_saa/adapters/qq.py +++ b/nonebot_plugin_saa/adapters/qq.py @@ -9,7 +9,6 @@ from ..auto_select_bot import register_list_targets from ..abstract_factories import ( MessageFactory, - MessageSegmentFactory, register_ms_adapter, assamble_message_factory, ) @@ -73,8 +72,8 @@ def _mention(m: Mention) -> MessageSegment: @register_qq(Reply) def _reply(r: Reply) -> MessageSegment: - assert isinstance(r.data, QQMessageId) - return MessageSegment.reference(r.data.message_id) + assert isinstance(mid := r.data["message_id"], QQMessageId) + return MessageSegment.reference(mid.message_id) @register_target_extractor(GuildMessageEvent) def extract_message_event(event: Event) -> PlatformTarget: @@ -138,10 +137,16 @@ async def revoke(self, hidetip=False): def raw(self): return self.msg_return + def extract_message_id(self) -> QQMessageId: + assert hasattr(self.msg_return, "id") + id = getattr(self.msg_return, "id") + assert isinstance(id, str) + return QQMessageId(message_id=id) + @register_sender(SupportedAdapters.qq) async def send( bot, - msg: MessageFactory[MessageSegmentFactory], + msg: MessageFactory, target: PlatformTarget, event: Optional[Event], at_sender: bool, diff --git a/nonebot_plugin_saa/adapters/qqguild.py b/nonebot_plugin_saa/adapters/qqguild.py index 5c46d394..ed919f13 100644 --- a/nonebot_plugin_saa/adapters/qqguild.py +++ b/nonebot_plugin_saa/adapters/qqguild.py @@ -9,7 +9,6 @@ from ..auto_select_bot import register_list_targets from ..abstract_factories import ( MessageFactory, - MessageSegmentFactory, register_ms_adapter, assamble_message_factory, ) @@ -63,8 +62,8 @@ def _mention(m: Mention) -> MessageSegment: @register_qqguild(Reply) def _reply(r: Reply) -> MessageSegment: - assert isinstance(r.data, QQGuildMessageId) - return MessageSegment.reference(r.data.message_id) + assert isinstance(mid := r.data["message_id"], QQGuildMessageId) + return MessageSegment.reference(mid.message_id) @register_target_extractor(MessageEvent) def extract_message_event(event: Event) -> PlatformTarget: @@ -108,10 +107,14 @@ async def revoke(self, hidetip=False): def raw(self): return self.sent_msg + def extract_message_id(self) -> QQGuildMessageId: + assert self.sent_msg.id + return QQGuildMessageId(message_id=self.sent_msg.id) + @register_sender(SupportedAdapters.qqguild) async def send( bot, - msg: MessageFactory[MessageSegmentFactory], + msg: MessageFactory, target, event, at_sender: bool, diff --git a/nonebot_plugin_saa/adapters/red.py b/nonebot_plugin_saa/adapters/red.py index 9ed7a2b4..48642755 100644 --- a/nonebot_plugin_saa/adapters/red.py +++ b/nonebot_plugin_saa/adapters/red.py @@ -11,7 +11,6 @@ from ..utils import SupportedAdapters, SupportedPlatform from ..abstract_factories import ( MessageFactory, - MessageSegmentFactory, AggregatedMessageFactory, register_ms_adapter, assamble_message_factory, @@ -72,11 +71,11 @@ async def _mention(m: Mention) -> MessageSegment: @register_red(Reply) async def _reply(r: Reply) -> MessageSegment: - assert isinstance(r.data, RedMessageId) + assert isinstance(mid := r.data["message_id"], RedMessageId) return MessageSegment.reply( - message_seq=r.data.message_seq, - message_id=r.data.message_id, - sender_uin=r.data.sender_uin, + message_seq=mid.message_seq, + message_id=mid.message_id, + sender_uin=mid.sender_uin, ) @register_target_extractor(PrivateMessageEvent) @@ -123,10 +122,17 @@ async def revoke(self): def raw(self) -> MessageModel: return self.message + def extract_message_id(self) -> RedMessageId: + return RedMessageId( + message_seq=self.message.msgSeq, + message_id=self.message.msgId, + sender_uin=self.message.senderUin, + ) + @register_sender(SupportedAdapters.red) async def send( bot, - msg: MessageFactory[MessageSegmentFactory], + msg: MessageFactory, target, event, at_sender: bool, diff --git a/nonebot_plugin_saa/adapters/telegram.py b/nonebot_plugin_saa/adapters/telegram.py index e8adf551..77c0fd88 100644 --- a/nonebot_plugin_saa/adapters/telegram.py +++ b/nonebot_plugin_saa/adapters/telegram.py @@ -11,7 +11,6 @@ from ..types import Text, Image, Reply, Mention from ..abstract_factories import ( MessageFactory, - MessageSegmentFactory, register_ms_adapter, assamble_message_factory, ) @@ -71,8 +70,8 @@ async def _mention(m: Mention) -> MessageSegment: @register_telegram(Reply) async def _reply(r: Reply) -> MessageSegment: - assert isinstance(r.data, TelegramMessageId) - return MessageSegment("reply", {"message_id": str(r.data.message_id)}) + assert isinstance(mid := r.data["message_id"], TelegramMessageId) + return MessageSegment("reply", {"message_id": str(mid.message_id)}) @register_target_extractor(PrivateMessageEvent) @register_target_extractor(GroupMessageEvent) @@ -129,6 +128,9 @@ async def revoke(self): def raw(self): return self.messages + def extract_message_id(self) -> List[TelegramMessageId]: + return [TelegramMessageId(message_id=x.message_id) for x in self.messages] + @register_message_id_getter(MessageEvent) def _(event: Event): assert isinstance(event, MessageEvent) @@ -137,7 +139,7 @@ def _(event: Event): @register_sender(SupportedAdapters.telegram) async def send( bot, - msg: MessageFactory[MessageSegmentFactory], + msg: MessageFactory, target, event, at_sender: bool, @@ -166,8 +168,10 @@ async def send( message_to_send = Message() for message_segment_factory in full_msg: if isinstance(message_segment_factory, Reply): - assert isinstance(message_segment_factory.data, TelegramMessageId) - reply_to_message_id = message_segment_factory.data.message_id + assert isinstance( + mid := message_segment_factory.data["message_id"], TelegramMessageId + ) + reply_to_message_id = mid.message_id continue if ( diff --git a/nonebot_plugin_saa/registries/message_id.py b/nonebot_plugin_saa/registries/message_id.py index 2a88e536..6f3a8fc9 100644 --- a/nonebot_plugin_saa/registries/message_id.py +++ b/nonebot_plugin_saa/registries/message_id.py @@ -1,3 +1,4 @@ +from abc import ABC from typing_extensions import Annotated from typing import Dict, Type, Callable, Optional @@ -8,7 +9,7 @@ from ..utils import SupportedAdapters -class MessageId(SerializationMeta): +class MessageId(SerializationMeta, ABC): _index_key = "adapter_name" adapter_name: SupportedAdapters diff --git a/nonebot_plugin_saa/registries/receipt.py b/nonebot_plugin_saa/registries/receipt.py index 73ed0554..f05cf354 100644 --- a/nonebot_plugin_saa/registries/receipt.py +++ b/nonebot_plugin_saa/registries/receipt.py @@ -1,8 +1,10 @@ from typing import Any +from abc import abstractmethod from nonebot import get_bot from nonebot.adapters import Bot +from .message_id import MessageId from .meta import SerializationMeta from ..utils import SupportedAdapters @@ -22,3 +24,8 @@ async def revoke(self): @property def raw(self) -> Any: ... + + @abstractmethod + def extract_message_id(self) -> MessageId: + """从 Receipt 中提取 MessageId""" + ... diff --git a/nonebot_plugin_saa/types/common_message_segment.py b/nonebot_plugin_saa/types/common_message_segment.py index 7d28aa1f..80c27597 100644 --- a/nonebot_plugin_saa/types/common_message_segment.py +++ b/nonebot_plugin_saa/types/common_message_segment.py @@ -1,6 +1,6 @@ from io import BytesIO from pathlib import Path -from typing import Union, TypedDict +from typing import Union, Literal, TypedDict from ..registries import MessageId from ..abstract_factories import MessageFactory, MessageSegmentFactory @@ -13,6 +13,7 @@ class TextData(TypedDict): class Text(MessageSegmentFactory): """文本消息段""" + type: Literal["text"] = "text" data: TextData def __init__(self, text: str) -> None: @@ -27,6 +28,9 @@ def __init__(self, text: str) -> None: def __str__(self) -> str: return self.data["text"] + def __len__(self) -> int: + return len(self.data["text"]) + MessageFactory.register_text_ms(lambda text: Text(text)) @@ -39,6 +43,7 @@ class ImageData(TypedDict): class Image(MessageSegmentFactory): """图片消息段""" + type: Literal["image"] = "image" data: ImageData def __init__( @@ -65,6 +70,7 @@ class MentionData(TypedDict): class Mention(MessageSegmentFactory): """提到其他用户""" + type: Literal["mention"] = "mention" data: MentionData def __init__(self, user_id: str): @@ -78,10 +84,15 @@ def __init__(self, user_id: str): self.data = {"user_id": user_id} +class ReplyData(TypedDict): + message_id: MessageId + + class Reply(MessageSegmentFactory): """回复其他消息的消息段""" - data: MessageId + type: Literal["reply"] = "reply" + data: ReplyData def __init__(self, message_id: MessageId): """回复其他消息的消息段 @@ -89,6 +100,5 @@ def __init__(self, message_id: MessageId): 参数: message_id: 需要回复消息的 MessageId """ - super().__init__() - self.data = message_id + self.data = {"message_id": message_id} diff --git a/tests/test_feishu.py b/tests/test_feishu.py index 953ba993..be8a5be5 100644 --- a/tests/test_feishu.py +++ b/tests/test_feishu.py @@ -5,8 +5,8 @@ import httpx from nonebug import App from nonebot import get_driver -from nonebot.adapters.feishu.bot import BotInfo from nonebot.adapters.feishu import Bot, Message +from nonebot.adapters.feishu.models import BotInfo from nonebot.adapters.feishu.config import BotConfig from nonebot_plugin_saa.utils import SupportedAdapters @@ -86,6 +86,8 @@ def mock_feishu_message_event(message: Message, group=False): message=PrivateEventMessage(chat_type="p2p", **event_message_dict), ), reply=None, + _message=message, + original_message=message, ) else: return GroupMessageEvent( @@ -96,6 +98,8 @@ def mock_feishu_message_event(message: Message, group=False): message=GroupEventMessage(chat_type="group", **event_message_dict), ), reply=None, + _message=message, + original_message=message, ) diff --git a/tests/test_message.py b/tests/test_message.py index 06f92477..20cc292e 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -4,8 +4,9 @@ from nonebot.adapters.onebot.v11.bot import Bot from nonebot.adapters.onebot.v11.message import MessageSegment -from nonebot_plugin_saa import Text, MessageFactory +from nonebot_plugin_saa.registries import MessageId from nonebot_plugin_saa.utils import SupportedAdapters +from nonebot_plugin_saa import Text, Image, Reply, Mention, MessageFactory def test_message_assamble(): @@ -34,3 +35,233 @@ async def test_build_message(app: App): msg = await msg_factory.build(bot) assert msg == MessageSegment.text("talk is cheap") + "show me the code" + + +def test_message_add(app: App): + s = "123" + t = Text("abc") + i = Image("http://example.com/abc.png") + r = Reply(MessageId(adapter_name=SupportedAdapters.fake)) + m = Mention("123") + + assert t == "abc" + assert t != "123" + assert t == Text("abc") + assert t != Text("123") + assert t != i + assert t != r + assert t != m + st = s + t + assert st == MessageFactory([Text("123"), Text("abc")]) + tt = t + t + assert tt == MessageFactory([Text("abc"), Text("abc")]) + ts = t + s + assert ts == MessageFactory([Text("abc"), Text("123")]) + assert ts == MessageFactory([t, s]) + + si = s + i + assert si == MessageFactory([Text("123"), Image("http://example.com/abc.png")]) + + is_ = i + s + assert is_ == MessageFactory([Image("http://example.com/abc.png"), Text("123")]) + + ti = t + i + assert ti == MessageFactory([Text("abc"), Image("http://example.com/abc.png")]) + + it = i + t + assert it == MessageFactory([Image("http://example.com/abc.png"), Text("abc")]) + + sit = s + i + t + assert sit == MessageFactory( + [Text("123"), Image("http://example.com/abc.png"), Text("abc")] + ) + assert sit == MessageFactory([s, i, t]) + assert sit == [s, i] + t + + tit = t + i + t + t_it = t + [i, t] + assert tit == MessageFactory( + [Text("abc"), Image("http://example.com/abc.png"), Text("abc")] + ) + assert tit == MessageFactory([t, i, t]) + assert tit == t_it + + tir = t + i + r + assert tir == t + [i, r] + assert tir == MessageFactory([t, i, r]) + ir = i + r + assert ti + r == t + ir + + tt_iadd = Text("q") + tt_iadd += t + assert tt_iadd == MessageFactory([Text("q"), t]) + + +def test_segment_data(): + assert len(Text("text")) == 4 + assert Text("text").get("data") == {"text": "text"} + assert list(Text("text").keys()) == ["type", "data"] + assert list(Text("text").values()) == ["text", {"text": "text"}] + assert list(Text("text").items()) == [ + ("type", "text"), + ("data", {"text": "text"}), + ] + + +def test_segment_join(): + seg = Text("test") + iterable = [ + Text("first"), + MessageFactory([Text("second"), Text("third")]), + ] + + assert seg.join(iterable) == MessageFactory( + [ + Text("first"), + Text("test"), + Text("second"), + Text("third"), + ] + ) + + +def test_segment_copy(): + origin = Text("text") + copy = origin.copy() + assert origin is not copy + assert origin == copy + + +def test_message_getitem(): + message = MessageFactory( + [ + Text("test"), + Image("test2"), + Image("test3"), + Text("test4"), + ] + ) + + assert message[0] == Text("test") + + assert message[:2] == MessageFactory([Text("test"), Image("test2")]) + + assert message["image"] == MessageFactory([Image("test2"), Image("test3")]) + + assert message["image", 0] == Image("test2") + assert message["image", 0:2] == message["image"] + + assert message.index(message[0]) == 0 + assert message.index("image") == 1 + + assert message.get("image") == message["image"] + assert message.get("image", 114514) == message["image"] + assert message.get("image", 1) == MessageFactory([message["image", 0]]) + + assert message.count("image") == 2 + + +def test_message_contains(): + message = MessageFactory( + [ + Text("test"), + Image("test2"), + Image("test3"), + Text("test4"), + ] + ) + + assert message.has(Text("test")) is True + assert Text("test") in message + assert message.has("image") is True + assert "image" in message + + assert message.has(Text("foo")) is False + assert Text("foo") not in message + assert message.has("foo") is False + assert "foo" not in message + + +def test_message_only(): + message = MessageFactory( + [ + Text("test"), + Text("test2"), + ] + ) + + assert message.only("text") is True + assert message.only(Text("test")) is False + + message = MessageFactory( + [ + Text("test"), + Image("test2"), + Image("test3"), + Text("test4"), + ] + ) + + assert message.only("text") is False + + message = MessageFactory( + [ + Text("test"), + Text("test"), + ] + ) + + assert message.only(Text("test")) is True + + +def test_message_join(): + msg = MessageFactory([Text("test")]) + iterable = [ + Text("first"), + MessageFactory([Text("second"), Text("third")]), + ] + + assert msg.join(iterable) == MessageFactory( + [ + Text("first"), + Text("test"), + Text("second"), + Text("third"), + ] + ) + + +def test_message_include(): + message = MessageFactory( + [ + Text("test"), + Image("test2"), + Image("test3"), + Text("test4"), + ] + ) + + assert message.include("text") == MessageFactory( + [ + Text("test"), + Text("test4"), + ] + ) + + +def test_message_exclude(): + message = MessageFactory( + [ + Text("test"), + Image("test2"), + Image("test3"), + Text("test4"), + ] + ) + + assert message.exclude("image") == MessageFactory( + [ + Text("test"), + Text("test4"), + ] + )