Skip to content

Commit

Permalink
feat(webhook): improve thread support for webhook messages (#1077)
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv authored Jan 20, 2024
1 parent d782f54 commit 4da720d
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 32 deletions.
1 change: 1 addition & 0 deletions changelog/1077.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for threads in :meth:`Webhook.fetch_message`, :meth:`~Webhook.edit_message`, and :meth:`~Webhook.delete_message`, as well as their sync counterparts.
91 changes: 73 additions & 18 deletions disnake/webhook/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from ..http import Route, set_attachments, to_multipart, to_multipart_with_attachments
from ..message import Message
from ..mixins import Hashable
from ..object import Object
from ..ui.action_row import MessageUIComponent, components_to_dict
from ..user import BaseUser, User

Expand Down Expand Up @@ -293,6 +294,7 @@ def execute_webhook(
params = {"wait": int(wait)}
if thread_id:
params["thread_id"] = thread_id

route = Route(
"POST",
"/webhooks/{webhook_id}/{webhook_token}",
Expand All @@ -310,15 +312,20 @@ def get_webhook_message(
message_id: int,
*,
session: aiohttp.ClientSession,
thread_id: Optional[int] = None,
) -> Response[MessagePayload]:
params: Dict[str, Any] = {}
if thread_id is not None:
params["thread_id"] = thread_id

route = Route(
"GET",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
)
return self.request(route, session)
return self.request(route, session, params=params)

def edit_webhook_message(
self,
Expand All @@ -330,15 +337,22 @@ def edit_webhook_message(
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
thread_id: Optional[int] = None,
) -> Response[Message]:
params: Dict[str, Any] = {}
if thread_id is not None:
params["thread_id"] = thread_id

route = Route(
"PATCH",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
)
return self.request(route, session, payload=payload, multipart=multipart, files=files)
return self.request(
route, session, payload=payload, multipart=multipart, files=files, params=params
)

def delete_webhook_message(
self,
Expand All @@ -347,15 +361,20 @@ def delete_webhook_message(
message_id: int,
*,
session: aiohttp.ClientSession,
thread_id: Optional[int] = None,
) -> Response[None]:
params: Dict[str, Any] = {}
if thread_id is not None:
params["thread_id"] = thread_id

route = Route(
"DELETE",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
)
return self.request(route, session)
return self.request(route, session, params=params)

def fetch_webhook(
self,
Expand Down Expand Up @@ -691,10 +710,14 @@ def __getattr__(self, attr) -> NoReturn:


class _WebhookState(Generic[WebhookT]):
__slots__ = ("_parent", "_webhook")
__slots__ = ("_parent", "_webhook", "_thread")

def __init__(
self, webhook: WebhookT, parent: Optional[Union[ConnectionState, _WebhookState]]
self,
webhook: WebhookT,
parent: Optional[Union[ConnectionState, _WebhookState]],
*,
thread: Optional[Snowflake] = None,
) -> None:
self._webhook: WebhookT = webhook

Expand All @@ -704,6 +727,8 @@ def __init__(
else:
self._parent = parent

self._thread: Optional[Snowflake] = thread

def _get_guild(self, guild_id):
if self._parent is not None:
return self._parent._get_guild(guild_id)
Expand Down Expand Up @@ -861,6 +886,7 @@ async def edit(
view=view,
components=components,
allowed_mentions=allowed_mentions,
thread=self._state._thread,
)

async def delete(self, *, delay: Optional[float] = None) -> None:
Expand Down Expand Up @@ -888,13 +914,13 @@ async def delete(self, *, delay: Optional[float] = None) -> None:
async def inner_call(delay: float = delay) -> None:
await asyncio.sleep(delay)
try:
await self._state._webhook.delete_message(self.id)
await self._state._webhook.delete_message(self.id, thread=self._state._thread)
except HTTPException:
pass

asyncio.create_task(inner_call())
else:
await self._state._webhook.delete_message(self.id)
await self._state._webhook.delete_message(self.id, thread=self._state._thread)


class BaseWebhook(Hashable):
Expand Down Expand Up @@ -1422,19 +1448,30 @@ async def edit(

return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state)

def _create_message(self, data):
state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...)
def _create_message(
self, data, *, thread: Optional[Snowflake] = None, thread_name: Optional[str] = None
):
channel_id = int(data["channel_id"])
# if the channel ID does not match, a new thread was created

# If channel IDs don't match, a new thread was most likely created;
# if the user passed a `thread_name`, assume this is the case and
# create a `thread` object for the state
if self.channel_id != channel_id and thread_name:
thread = Object(id=channel_id)

state = _WebhookState(self, parent=self._state, thread=thread)

# If the channel IDs don't match, the message was created in a thread
if self.channel_id != channel_id:
guild = self.guild
msg_channel = guild and guild.get_channel_or_thread(channel_id)
else:
msg_channel = self.channel

if not msg_channel:
# state may be artificial (unlikely at this point...)
msg_channel = PartialMessageable(state=self._state, id=channel_id) # type: ignore

# state is artificial
return WebhookMessage(data=data, state=state, channel=msg_channel) # type: ignore

Expand Down Expand Up @@ -1588,7 +1625,7 @@ async def send(
.. versionadded:: 2.4
thread: :class:`~disnake.abc.Snowflake`
The thread to send this webhook to.
The thread to send this message to.
.. versionadded:: 2.0
Expand Down Expand Up @@ -1732,7 +1769,7 @@ async def send(

msg = None
if wait:
msg = self._create_message(data)
msg = self._create_message(data, thread=thread, thread_name=thread_name)
if delete_after is not MISSING:
await msg.delete(delay=delete_after)

Expand All @@ -1742,7 +1779,7 @@ async def send(

return msg

async def fetch_message(self, id: int) -> WebhookMessage:
async def fetch_message(self, id: int, *, thread: Optional[Snowflake] = None) -> WebhookMessage:
"""|coro|
Retrieves a single :class:`WebhookMessage` owned by this webhook.
Expand All @@ -1756,6 +1793,10 @@ async def fetch_message(self, id: int) -> WebhookMessage:
----------
id: :class:`int`
The message ID to look for.
thread: Optional[:class:`~disnake.abc.Snowflake`]
The thread the message is in, if any.
.. versionadded:: 2.10
Raises
------
Expand All @@ -1782,8 +1823,9 @@ async def fetch_message(self, id: int) -> WebhookMessage:
self.token,
id,
session=self.session,
thread_id=thread.id if thread else None,
)
return self._create_message(data)
return self._create_message(data, thread=thread)

async def edit_message(
self,
Expand All @@ -1798,6 +1840,7 @@ async def edit_message(
view: Optional[View] = MISSING,
components: Optional[Components[MessageUIComponent]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
thread: Optional[Snowflake] = None,
) -> WebhookMessage:
"""|coro|
Expand Down Expand Up @@ -1873,6 +1916,10 @@ async def edit_message(
allowed_mentions: :class:`AllowedMentions`
Controls the mentions being processed in this message.
See :meth:`.abc.Messageable.send` for more information.
thread: Optional[:class:`~disnake.abc.Snowflake`]
The thread the message is in, if any.
.. versionadded:: 2.10
Raises
------
Expand Down Expand Up @@ -1905,7 +1952,7 @@ async def edit_message(
# if no attachment list was provided but we're uploading new files,
# use current attachments as the base
if attachments is MISSING and (file or files):
attachments = (await self.fetch_message(message_id)).attachments
attachments = (await self.fetch_message(message_id, thread=thread)).attachments

previous_mentions: Optional[AllowedMentions] = getattr(
self._state, "allowed_mentions", None
Expand All @@ -1929,6 +1976,7 @@ async def edit_message(
self.token,
message_id,
session=self.session,
thread_id=thread.id if thread else None,
payload=params.payload,
multipart=params.multipart,
files=params.files,
Expand All @@ -1938,12 +1986,14 @@ async def edit_message(
for f in params.files:
f.close()

message = self._create_message(data)
message = self._create_message(data, thread=thread)
if view and not view.is_finished():
self._state.store_view(view, message_id)
return message

async def delete_message(self, message_id: int, /) -> None:
async def delete_message(
self, message_id: int, /, *, thread: Optional[Snowflake] = None
) -> None:
"""|coro|
Deletes a message owned by this webhook.
Expand All @@ -1960,6 +2010,10 @@ async def delete_message(self, message_id: int, /) -> None:
----------
message_id: :class:`int`
The ID of the message to delete.
thread: Optional[:class:`~disnake.abc.Snowflake`]
The thread the message is in, if any.
.. versionadded:: 2.10
Raises
------
Expand All @@ -1979,4 +2033,5 @@ async def delete_message(self, message_id: int, /) -> None:
self.token,
message_id,
session=self.session,
thread_id=thread.id if thread else None,
)
Loading

0 comments on commit 4da720d

Please sign in to comment.