From 8c40318c3d8ce3f55f1a426acf5182567ce1cd3b Mon Sep 17 00:00:00 2001 From: shiftinv Date: Fri, 14 Jul 2023 19:05:02 +0200 Subject: [PATCH 1/3] feat: support threads for editing/deleting webhook messages --- disnake/webhook/async_.py | 76 ++++++++++++++++++++++++++++++--------- disnake/webhook/sync.py | 61 ++++++++++++++++++++++++------- 2 files changed, 109 insertions(+), 28 deletions(-) diff --git a/disnake/webhook/async_.py b/disnake/webhook/async_.py index edd9ec3dcd..021817dd07 100644 --- a/disnake/webhook/async_.py +++ b/disnake/webhook/async_.py @@ -293,6 +293,7 @@ def execute_webhook( params = {"wait": int(wait)} if thread_id: params["thread_id"] = thread_id + route = Route( "POST", "/webhooks/{webhook_id}/{webhook_token}", @@ -310,7 +311,12 @@ def get_webhook_message( message_id: int, *, session: aiohttp.ClientSession, + thread_id: Optional[int] = None, ) -> Response[MessagePayload]: + params: Dict[str, Any] = {} + if thread_id is not None: + params["thread_id"] = thread_id + route = Route( "GET", "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}", @@ -318,7 +324,7 @@ def get_webhook_message( webhook_token=token, message_id=message_id, ) - return self.request(route, session) + return self.request(route, session, params=params) def edit_webhook_message( self, @@ -330,7 +336,12 @@ def edit_webhook_message( payload: Optional[Dict[str, Any]] = None, multipart: Optional[List[Dict[str, Any]]] = None, files: Optional[List[File]] = None, + thread_id: Optional[int] = None, ) -> Response[Message]: + params: Dict[str, Any] = {} + if thread_id is not None: + params["thread_id"] = thread_id + route = Route( "PATCH", "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}", @@ -338,7 +349,9 @@ def edit_webhook_message( webhook_token=token, message_id=message_id, ) - return self.request(route, session, payload=payload, multipart=multipart, files=files) + return self.request( + route, session, payload=payload, multipart=multipart, files=files, params=params + ) def delete_webhook_message( self, @@ -347,7 +360,12 @@ def delete_webhook_message( message_id: int, *, session: aiohttp.ClientSession, + thread_id: Optional[int] = None, ) -> Response[None]: + params: Dict[str, Any] = {} + if thread_id is not None: + params["thread_id"] = thread_id + route = Route( "DELETE", "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}", @@ -355,7 +373,7 @@ def delete_webhook_message( webhook_token=token, message_id=message_id, ) - return self.request(route, session) + return self.request(route, session, params=params) def fetch_webhook( self, @@ -684,10 +702,14 @@ def __getattr__(self, attr) -> NoReturn: class _WebhookState(Generic[WebhookT]): - __slots__ = ("_parent", "_webhook") + __slots__ = ("_parent", "_webhook", "_thread") def __init__( - self, webhook: WebhookT, parent: Optional[Union[ConnectionState, _WebhookState]] + self, + webhook: WebhookT, + parent: Optional[Union[ConnectionState, _WebhookState]], + *, + thread: Optional[Snowflake] = None, ) -> None: self._webhook: WebhookT = webhook @@ -697,6 +719,8 @@ def __init__( else: self._parent = parent + self._thread: Optional[Snowflake] = thread + def _get_guild(self, guild_id): if self._parent is not None: return self._parent._get_guild(guild_id) @@ -854,6 +878,7 @@ async def edit( view=view, components=components, allowed_mentions=allowed_mentions, + thread=self._state._thread, ) async def delete(self, *, delay: Optional[float] = None) -> None: @@ -881,13 +906,13 @@ async def delete(self, *, delay: Optional[float] = None) -> None: async def inner_call(delay: float = delay) -> None: await asyncio.sleep(delay) try: - await self._state._webhook.delete_message(self.id) + await self._state._webhook.delete_message(self.id, thread=self._state._thread) except HTTPException: pass asyncio.create_task(inner_call()) else: - await self._state._webhook.delete_message(self.id) + await self._state._webhook.delete_message(self.id, thread=self._state._thread) class BaseWebhook(Hashable): @@ -1415,8 +1440,8 @@ async def edit( return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state) - def _create_message(self, data): - state = _WebhookState(self, parent=self._state) + def _create_message(self, data, thread: Optional[Snowflake] = None): + state = _WebhookState(self, parent=self._state, thread=thread) # state may be artificial (unlikely at this point...) channel_id = int(data["channel_id"]) # if the channel ID does not match, a new thread was created @@ -1574,7 +1599,7 @@ async def send( .. versionadded:: 2.4 thread: :class:`~disnake.abc.Snowflake` - The thread to send this webhook to. + The thread to send this message to. .. versionadded:: 2.0 @@ -1685,6 +1710,7 @@ async def send( flags=flags, view=view, components=components, + # TODO: check message.edit for this thread_name=thread_name, allowed_mentions=allowed_mentions, previous_allowed_mentions=previous_mentions, @@ -1710,7 +1736,7 @@ async def send( msg = None if wait: - msg = self._create_message(data) + msg = self._create_message(data, thread=thread) if delete_after is not MISSING: await msg.delete(delay=delete_after) @@ -1720,7 +1746,7 @@ async def send( return msg - async def fetch_message(self, id: int) -> WebhookMessage: + async def fetch_message(self, id: int, *, thread: Optional[Snowflake] = None) -> WebhookMessage: """|coro| Retrieves a single :class:`WebhookMessage` owned by this webhook. @@ -1734,6 +1760,10 @@ async def fetch_message(self, id: int) -> WebhookMessage: ---------- id: :class:`int` The message ID to look for. + thread: Optional[:class:`~disnake.abc.Snowflake`] + The thread the message is in, if any. + + .. versionadded:: 2.10 Raises ------ @@ -1760,8 +1790,9 @@ async def fetch_message(self, id: int) -> WebhookMessage: self.token, id, session=self.session, + thread_id=thread.id if thread else None, ) - return self._create_message(data) + return self._create_message(data, thread=thread) async def edit_message( self, @@ -1776,6 +1807,7 @@ async def edit_message( view: Optional[View] = MISSING, components: Optional[Components[MessageUIComponent]] = MISSING, allowed_mentions: Optional[AllowedMentions] = None, + thread: Optional[Snowflake] = None, ) -> WebhookMessage: """|coro| @@ -1851,6 +1883,10 @@ async def edit_message( allowed_mentions: :class:`AllowedMentions` Controls the mentions being processed in this message. See :meth:`.abc.Messageable.send` for more information. + thread: Optional[:class:`~disnake.abc.Snowflake`] + The thread the message is in, if any. + + .. versionadded:: 2.10 Raises ------ @@ -1883,7 +1919,7 @@ async def edit_message( # if no attachment list was provided but we're uploading new files, # use current attachments as the base if attachments is MISSING and (file or files): - attachments = (await self.fetch_message(message_id)).attachments + attachments = (await self.fetch_message(message_id, thread=thread)).attachments previous_mentions: Optional[AllowedMentions] = getattr( self._state, "allowed_mentions", None @@ -1907,6 +1943,7 @@ async def edit_message( self.token, message_id, session=self.session, + thread_id=thread.id if thread else None, payload=params.payload, multipart=params.multipart, files=params.files, @@ -1916,12 +1953,14 @@ async def edit_message( for f in params.files: f.close() - message = self._create_message(data) + message = self._create_message(data, thread=thread) if view and not view.is_finished(): self._state.store_view(view, message_id) return message - async def delete_message(self, message_id: int, /) -> None: + async def delete_message( + self, message_id: int, /, *, thread: Optional[Snowflake] = None + ) -> None: """|coro| Deletes a message owned by this webhook. @@ -1938,6 +1977,10 @@ async def delete_message(self, message_id: int, /) -> None: ---------- message_id: :class:`int` The ID of the message to delete. + thread: Optional[:class:`~disnake.abc.Snowflake`] + The thread the message is in, if any. + + .. versionadded:: 2.10 Raises ------ @@ -1957,4 +2000,5 @@ async def delete_message(self, message_id: int, /) -> None: self.token, message_id, session=self.session, + thread_id=thread.id if thread else None, ) diff --git a/disnake/webhook/sync.py b/disnake/webhook/sync.py index df14470637..8a6efa2c6d 100644 --- a/disnake/webhook/sync.py +++ b/disnake/webhook/sync.py @@ -271,6 +271,7 @@ def execute_webhook( params = {"wait": int(wait)} if thread_id: params["thread_id"] = thread_id + route = Route( "POST", "/webhooks/{webhook_id}/{webhook_token}", @@ -288,7 +289,12 @@ def get_webhook_message( message_id: int, *, session: Session, + thread_id: Optional[int] = None, ): + params: Dict[str, Any] = {} + if thread_id is not None: + params["thread_id"] = thread_id + route = Route( "GET", "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}", @@ -296,7 +302,7 @@ def get_webhook_message( webhook_token=token, message_id=message_id, ) - return self.request(route, session) + return self.request(route, session, params=params) def edit_webhook_message( self, @@ -308,7 +314,12 @@ def edit_webhook_message( payload: Optional[Dict[str, Any]] = None, multipart: Optional[List[Dict[str, Any]]] = None, files: Optional[List[File]] = None, + thread_id: Optional[int] = None, ): + params: Dict[str, Any] = {} + if thread_id is not None: + params["thread_id"] = thread_id + route = Route( "PATCH", "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}", @@ -316,7 +327,9 @@ def edit_webhook_message( webhook_token=token, message_id=message_id, ) - return self.request(route, session, payload=payload, multipart=multipart, files=files) + return self.request( + route, session, payload=payload, multipart=multipart, files=files, params=params + ) def delete_webhook_message( self, @@ -325,7 +338,12 @@ def delete_webhook_message( message_id: int, *, session: Session, + thread_id: Optional[int] = None, ): + params: Dict[str, Any] = {} + if thread_id is not None: + params["thread_id"] = thread_id + route = Route( "DELETE", "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}", @@ -333,7 +351,7 @@ def delete_webhook_message( webhook_token=token, message_id=message_id, ) - return self.request(route, session) + return self.request(route, session, params=params) def fetch_webhook( self, @@ -475,6 +493,7 @@ def edit( files=files, attachments=attachments, allowed_mentions=allowed_mentions, + thread=self._state._thread, ) def delete(self, *, delay: Optional[float] = None) -> None: @@ -502,7 +521,7 @@ def delete(self, *, delay: Optional[float] = None) -> None: """ if delay is not None: time.sleep(delay) - self._state._webhook.delete_message(self.id) + self._state._webhook.delete_message(self.id, thread=self._state._thread) class SyncWebhook(BaseWebhook): @@ -845,8 +864,8 @@ def edit( data=data, session=self.session, token=self.auth_token, state=self._state ) - def _create_message(self, data): - state = _WebhookState(self, parent=self._state) + def _create_message(self, data, thread: Optional[Snowflake] = None): + state = _WebhookState(self, parent=self._state, thread=thread) # state may be artificial (unlikely at this point...) channel = self.channel channel_id = int(data["channel_id"]) @@ -1060,9 +1079,11 @@ def send( for f in params.files: f.close() if wait: - return self._create_message(data) + return self._create_message(data, thread=thread) - def fetch_message(self, id: int, /) -> SyncWebhookMessage: + def fetch_message( + self, id: int, /, *, thread: Optional[Snowflake] = None + ) -> SyncWebhookMessage: """Retrieves a single :class:`SyncWebhookMessage` owned by this webhook. .. versionadded:: 2.0 @@ -1074,6 +1095,10 @@ def fetch_message(self, id: int, /) -> SyncWebhookMessage: ---------- id: :class:`int` The message ID to look for. + thread: Optional[:class:`~disnake.abc.Snowflake`] + The thread the message is in, if any. + + .. versionadded:: 2.10 Raises ------ @@ -1100,8 +1125,9 @@ def fetch_message(self, id: int, /) -> SyncWebhookMessage: self.token, id, session=self.session, + thread_id=thread.id if thread else None, ) - return self._create_message(data) + return self._create_message(data, thread=thread) def edit_message( self, @@ -1114,6 +1140,7 @@ def edit_message( files: List[File] = MISSING, attachments: Optional[List[Attachment]] = MISSING, allowed_mentions: Optional[AllowedMentions] = None, + thread: Optional[Snowflake] = None, ) -> SyncWebhookMessage: """Edits a message owned by this webhook. @@ -1166,6 +1193,10 @@ def edit_message( allowed_mentions: :class:`AllowedMentions` Controls the mentions being processed in this message. See :meth:`.abc.Messageable.send` for more information. + thread: Optional[:class:`~disnake.abc.Snowflake`] + The thread the message is in, if any. + + .. versionadded:: 2.10 Raises ------ @@ -1186,7 +1217,7 @@ def edit_message( # if no attachment list was provided but we're uploading new files, # use current attachments as the base if attachments is MISSING and (file or files): - attachments = self.fetch_message(message_id).attachments + attachments = self.fetch_message(message_id, thread=thread).attachments previous_mentions: Optional[AllowedMentions] = getattr( self._state, "allowed_mentions", None @@ -1208,6 +1239,7 @@ def edit_message( self.token, message_id, session=self.session, + thread_id=thread.id if thread else None, payload=params.payload, multipart=params.multipart, files=params.files, @@ -1216,9 +1248,9 @@ def edit_message( if params.files: for f in params.files: f.close() - return self._create_message(data) + return self._create_message(data, thread=thread) - def delete_message(self, message_id: int, /) -> None: + def delete_message(self, message_id: int, /, *, thread: Optional[Snowflake] = None) -> None: """Deletes a message owned by this webhook. This is a lower level interface to :meth:`WebhookMessage.delete` in case @@ -1233,6 +1265,10 @@ def delete_message(self, message_id: int, /) -> None: ---------- message_id: :class:`int` The ID of the message to delete. + thread: Optional[:class:`~disnake.abc.Snowflake`] + The thread the message is in, if any. + + .. versionadded:: 2.10 Raises ------ @@ -1252,4 +1288,5 @@ def delete_message(self, message_id: int, /) -> None: self.token, message_id, session=self.session, + thread_id=thread.id if thread else None, ) From 00d59e82d9e1db3e41fc192c8652738ebf6c7e4f Mon Sep 17 00:00:00 2001 From: shiftinv Date: Fri, 14 Jul 2023 19:55:19 +0200 Subject: [PATCH 2/3] fix: set thread in webhook state correctly when using `thread_name` --- disnake/webhook/async_.py | 23 +++++++++++++++++------ disnake/webhook/sync.py | 14 ++++++++++---- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/disnake/webhook/async_.py b/disnake/webhook/async_.py index 021817dd07..e79d43e78b 100644 --- a/disnake/webhook/async_.py +++ b/disnake/webhook/async_.py @@ -37,6 +37,7 @@ from ..http import Route, set_attachments, to_multipart, to_multipart_with_attachments from ..message import Message from ..mixins import Hashable +from ..object import Object from ..ui.action_row import MessageUIComponent, components_to_dict from ..user import BaseUser, User @@ -1440,19 +1441,30 @@ async def edit( return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state) - def _create_message(self, data, thread: Optional[Snowflake] = None): - state = _WebhookState(self, parent=self._state, thread=thread) - # state may be artificial (unlikely at this point...) + def _create_message( + self, data, *, thread: Optional[Snowflake] = None, thread_name: Optional[str] = None + ): channel_id = int(data["channel_id"]) - # if the channel ID does not match, a new thread was created + + # If channel IDs don't match, a new thread was most likely created; + # if the user passed a `thread_name`, assume this is the case and + # create a `thread` object for the state + if self.channel_id != channel_id and thread_name: + thread = Object(id=channel_id) + + state = _WebhookState(self, parent=self._state, thread=thread) + + # If the channel IDs don't match, the message was created in a thread if self.channel_id != channel_id: guild = self.guild msg_channel = guild and guild.get_channel_or_thread(channel_id) else: msg_channel = self.channel + if not msg_channel: # state may be artificial (unlikely at this point...) msg_channel = PartialMessageable(state=self._state, id=channel_id) # type: ignore + # state is artificial return WebhookMessage(data=data, state=state, channel=msg_channel) # type: ignore @@ -1710,7 +1722,6 @@ async def send( flags=flags, view=view, components=components, - # TODO: check message.edit for this thread_name=thread_name, allowed_mentions=allowed_mentions, previous_allowed_mentions=previous_mentions, @@ -1736,7 +1747,7 @@ async def send( msg = None if wait: - msg = self._create_message(data, thread=thread) + msg = self._create_message(data, thread=thread, thread_name=thread_name) if delete_after is not MISSING: await msg.delete(delay=delete_after) diff --git a/disnake/webhook/sync.py b/disnake/webhook/sync.py index 8a6efa2c6d..8190ece5bc 100644 --- a/disnake/webhook/sync.py +++ b/disnake/webhook/sync.py @@ -22,6 +22,7 @@ from ..flags import MessageFlags from ..http import Route from ..message import Message +from ..object import Object from .async_ import BaseWebhook, _WebhookState, handle_message_parameters __all__ = ( @@ -864,11 +865,16 @@ def edit( data=data, session=self.session, token=self.auth_token, state=self._state ) - def _create_message(self, data, thread: Optional[Snowflake] = None): + def _create_message( + self, data, *, thread: Optional[Snowflake] = None, thread_name: Optional[str] = None + ): + # see async webhook's _create_message for details + channel_id = int(data["channel_id"]) + if self.channel_id != channel_id and thread_name: + thread = Object(id=channel_id) + state = _WebhookState(self, parent=self._state, thread=thread) - # state may be artificial (unlikely at this point...) channel = self.channel - channel_id = int(data["channel_id"]) if not channel or self.channel_id != channel_id: channel = PartialMessageable(state=self._state, id=channel_id) # type: ignore # state is artificial @@ -1079,7 +1085,7 @@ def send( for f in params.files: f.close() if wait: - return self._create_message(data, thread=thread) + return self._create_message(data, thread=thread, thread_name=thread_name) def fetch_message( self, id: int, /, *, thread: Optional[Snowflake] = None From 8af0b763141312286edb08d0f87fb4c9b927200d Mon Sep 17 00:00:00 2001 From: shiftinv Date: Fri, 14 Jul 2023 19:58:44 +0200 Subject: [PATCH 3/3] docs: add changelog entry --- changelog/1077.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/1077.feature.rst diff --git a/changelog/1077.feature.rst b/changelog/1077.feature.rst new file mode 100644 index 0000000000..2ba901b3f9 --- /dev/null +++ b/changelog/1077.feature.rst @@ -0,0 +1 @@ +Add support for threads in :meth:`Webhook.fetch_message`, :meth:`~Webhook.edit_message`, and :meth:`~Webhook.delete_message`, as well as their sync counterparts.