Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(webhook): improve thread support for webhook messages #1077

Merged
merged 4 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/1077.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for threads in :meth:`Webhook.fetch_message`, :meth:`~Webhook.edit_message`, and :meth:`~Webhook.delete_message`, as well as their sync counterparts.
91 changes: 73 additions & 18 deletions disnake/webhook/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from ..http import Route, set_attachments, to_multipart, to_multipart_with_attachments
from ..message import Message
from ..mixins import Hashable
from ..object import Object
from ..ui.action_row import MessageUIComponent, components_to_dict
from ..user import BaseUser, User

Expand Down Expand Up @@ -293,6 +294,7 @@ def execute_webhook(
params = {"wait": int(wait)}
if thread_id:
params["thread_id"] = thread_id

route = Route(
"POST",
"/webhooks/{webhook_id}/{webhook_token}",
Expand All @@ -310,15 +312,20 @@ def get_webhook_message(
message_id: int,
*,
session: aiohttp.ClientSession,
thread_id: Optional[int] = None,
) -> Response[MessagePayload]:
params: Dict[str, Any] = {}
if thread_id is not None:
params["thread_id"] = thread_id

route = Route(
"GET",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
)
return self.request(route, session)
return self.request(route, session, params=params)

def edit_webhook_message(
self,
Expand All @@ -330,15 +337,22 @@ def edit_webhook_message(
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
thread_id: Optional[int] = None,
) -> Response[Message]:
params: Dict[str, Any] = {}
if thread_id is not None:
params["thread_id"] = thread_id

route = Route(
"PATCH",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
)
return self.request(route, session, payload=payload, multipart=multipart, files=files)
return self.request(
route, session, payload=payload, multipart=multipart, files=files, params=params
)

def delete_webhook_message(
self,
Expand All @@ -347,15 +361,20 @@ def delete_webhook_message(
message_id: int,
*,
session: aiohttp.ClientSession,
thread_id: Optional[int] = None,
) -> Response[None]:
params: Dict[str, Any] = {}
if thread_id is not None:
params["thread_id"] = thread_id

route = Route(
"DELETE",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
)
return self.request(route, session)
return self.request(route, session, params=params)

def fetch_webhook(
self,
Expand Down Expand Up @@ -691,10 +710,14 @@ def __getattr__(self, attr) -> NoReturn:


class _WebhookState(Generic[WebhookT]):
__slots__ = ("_parent", "_webhook")
__slots__ = ("_parent", "_webhook", "_thread")

def __init__(
self, webhook: WebhookT, parent: Optional[Union[ConnectionState, _WebhookState]]
self,
webhook: WebhookT,
parent: Optional[Union[ConnectionState, _WebhookState]],
*,
thread: Optional[Snowflake] = None,
) -> None:
self._webhook: WebhookT = webhook

Expand All @@ -704,6 +727,8 @@ def __init__(
else:
self._parent = parent

self._thread: Optional[Snowflake] = thread

def _get_guild(self, guild_id):
if self._parent is not None:
return self._parent._get_guild(guild_id)
Expand Down Expand Up @@ -861,6 +886,7 @@ async def edit(
view=view,
components=components,
allowed_mentions=allowed_mentions,
thread=self._state._thread,
)

async def delete(self, *, delay: Optional[float] = None) -> None:
Expand Down Expand Up @@ -888,13 +914,13 @@ async def delete(self, *, delay: Optional[float] = None) -> None:
async def inner_call(delay: float = delay) -> None:
await asyncio.sleep(delay)
try:
await self._state._webhook.delete_message(self.id)
await self._state._webhook.delete_message(self.id, thread=self._state._thread)
except HTTPException:
pass

asyncio.create_task(inner_call())
else:
await self._state._webhook.delete_message(self.id)
await self._state._webhook.delete_message(self.id, thread=self._state._thread)


class BaseWebhook(Hashable):
Expand Down Expand Up @@ -1422,19 +1448,30 @@ async def edit(

return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state)

def _create_message(self, data):
state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...)
def _create_message(
self, data, *, thread: Optional[Snowflake] = None, thread_name: Optional[str] = None
):
channel_id = int(data["channel_id"])
# if the channel ID does not match, a new thread was created

# If channel IDs don't match, a new thread was most likely created;
# if the user passed a `thread_name`, assume this is the case and
# create a `thread` object for the state
if self.channel_id != channel_id and thread_name:
thread = Object(id=channel_id)

state = _WebhookState(self, parent=self._state, thread=thread)

# If the channel IDs don't match, the message was created in a thread
if self.channel_id != channel_id:
guild = self.guild
msg_channel = guild and guild.get_channel_or_thread(channel_id)
else:
msg_channel = self.channel

if not msg_channel:
# state may be artificial (unlikely at this point...)
msg_channel = PartialMessageable(state=self._state, id=channel_id) # type: ignore

# state is artificial
return WebhookMessage(data=data, state=state, channel=msg_channel) # type: ignore

Expand Down Expand Up @@ -1588,7 +1625,7 @@ async def send(
.. versionadded:: 2.4

thread: :class:`~disnake.abc.Snowflake`
The thread to send this webhook to.
The thread to send this message to.

.. versionadded:: 2.0

Expand Down Expand Up @@ -1732,7 +1769,7 @@ async def send(

msg = None
if wait:
msg = self._create_message(data)
msg = self._create_message(data, thread=thread, thread_name=thread_name)
if delete_after is not MISSING:
await msg.delete(delay=delete_after)

Expand All @@ -1742,7 +1779,7 @@ async def send(

return msg

async def fetch_message(self, id: int) -> WebhookMessage:
async def fetch_message(self, id: int, *, thread: Optional[Snowflake] = None) -> WebhookMessage:
"""|coro|

Retrieves a single :class:`WebhookMessage` owned by this webhook.
Expand All @@ -1756,6 +1793,10 @@ async def fetch_message(self, id: int) -> WebhookMessage:
----------
id: :class:`int`
The message ID to look for.
thread: Optional[:class:`~disnake.abc.Snowflake`]
The thread the message is in, if any.

.. versionadded:: 2.10

Raises
------
Expand All @@ -1782,8 +1823,9 @@ async def fetch_message(self, id: int) -> WebhookMessage:
self.token,
id,
session=self.session,
thread_id=thread.id if thread else None,
)
return self._create_message(data)
return self._create_message(data, thread=thread)

async def edit_message(
self,
Expand All @@ -1798,6 +1840,7 @@ async def edit_message(
view: Optional[View] = MISSING,
components: Optional[Components[MessageUIComponent]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
thread: Optional[Snowflake] = None,
) -> WebhookMessage:
"""|coro|

Expand Down Expand Up @@ -1873,6 +1916,10 @@ async def edit_message(
allowed_mentions: :class:`AllowedMentions`
Controls the mentions being processed in this message.
See :meth:`.abc.Messageable.send` for more information.
thread: Optional[:class:`~disnake.abc.Snowflake`]
The thread the message is in, if any.

.. versionadded:: 2.10

Raises
------
Expand Down Expand Up @@ -1905,7 +1952,7 @@ async def edit_message(
# if no attachment list was provided but we're uploading new files,
# use current attachments as the base
if attachments is MISSING and (file or files):
attachments = (await self.fetch_message(message_id)).attachments
attachments = (await self.fetch_message(message_id, thread=thread)).attachments

previous_mentions: Optional[AllowedMentions] = getattr(
self._state, "allowed_mentions", None
Expand All @@ -1929,6 +1976,7 @@ async def edit_message(
self.token,
message_id,
session=self.session,
thread_id=thread.id if thread else None,
payload=params.payload,
multipart=params.multipart,
files=params.files,
Expand All @@ -1938,12 +1986,14 @@ async def edit_message(
for f in params.files:
f.close()

message = self._create_message(data)
message = self._create_message(data, thread=thread)
if view and not view.is_finished():
self._state.store_view(view, message_id)
return message

async def delete_message(self, message_id: int, /) -> None:
async def delete_message(
self, message_id: int, /, *, thread: Optional[Snowflake] = None
) -> None:
"""|coro|

Deletes a message owned by this webhook.
Expand All @@ -1960,6 +2010,10 @@ async def delete_message(self, message_id: int, /) -> None:
----------
message_id: :class:`int`
The ID of the message to delete.
thread: Optional[:class:`~disnake.abc.Snowflake`]
The thread the message is in, if any.

.. versionadded:: 2.10

Raises
------
Expand All @@ -1979,4 +2033,5 @@ async def delete_message(self, message_id: int, /) -> None:
self.token,
message_id,
session=self.session,
thread_id=thread.id if thread else None,
)
Loading