Skip to content

Commit

Permalink
Merge branch 'master' into refactor/async
Browse files Browse the repository at this point in the history
Signed-off-by: lena <[email protected]>
  • Loading branch information
elenakrittik authored Nov 18, 2023
2 parents 94ceb0b + cd48c92 commit f38fed2
Show file tree
Hide file tree
Showing 56 changed files with 254 additions and 194 deletions.
1 change: 1 addition & 0 deletions changelog/1094.doc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add inherited attributes to :class:`TeamMember`, and fix :attr:`TeamMember.avatar` documentation.
1 change: 1 addition & 0 deletions changelog/1094.feature.0.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add :attr:`TeamMember.role`.
1 change: 1 addition & 0 deletions changelog/1094.feature.1.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
|commands| Update :meth:`Bot.is_owner <ext.commands.Bot.is_owner>` to take team member roles into account.
2 changes: 1 addition & 1 deletion disnake/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ async def _edit(
if p_id is not None and (parent := self.guild.get_channel(p_id)):
overwrites_payload = [c._asdict() for c in parent._overwrites]

if overwrites is not MISSING and overwrites is not None:
if overwrites not in (MISSING, None):
overwrites_payload = []
for target, perm in overwrites.items():
if not isinstance(perm, PermissionOverwrite):
Expand Down
2 changes: 1 addition & 1 deletion disnake/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ def create_activity(
elif game_type is ActivityType.listening and "sync_id" in data and "session_id" in data:
activity = Spotify(**data)
else:
activity = Activity(**data)
activity = Activity(**data) # type: ignore

if isinstance(activity, (Activity, CustomActivity)) and activity.emoji and state:
activity.emoji._state = state
Expand Down
2 changes: 1 addition & 1 deletion disnake/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"]
AnyState = Union[ConnectionState, _WebhookState[BaseWebhook]]

AssetBytes = Union[bytes, "AssetMixin"]
AssetBytes = Union[utils._BytesLike, "AssetMixin"]

VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"})
VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
Expand Down
2 changes: 1 addition & 1 deletion disnake/audit_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def _transform_datetime(entry: AuditLogEntry, data: Optional[str]) -> Optional[d


def _transform_privacy_level(
entry: AuditLogEntry, data: int
entry: AuditLogEntry, data: Optional[int]
) -> Optional[Union[enums.StagePrivacyLevel, enums.GuildScheduledEventPrivacyLevel]]:
if data is None:
return None
Expand Down
12 changes: 6 additions & 6 deletions disnake/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ async def edit(
overwrites=overwrites,
flags=flags,
reason=reason,
**kwargs,
**kwargs, # type: ignore
)
if payload is not None:
# the payload will always be the proper channel payload
Expand Down Expand Up @@ -1628,7 +1628,7 @@ async def edit(
slowmode_delay=slowmode_delay,
flags=flags,
reason=reason,
**kwargs,
**kwargs, # type: ignore
)
if payload is not None:
# the payload will always be the proper channel payload
Expand Down Expand Up @@ -2453,7 +2453,7 @@ async def edit(
flags=flags,
slowmode_delay=slowmode_delay,
reason=reason,
**kwargs,
**kwargs, # type: ignore
)
if payload is not None:
# the payload will always be the proper channel payload
Expand Down Expand Up @@ -2946,7 +2946,7 @@ async def edit(
overwrites=overwrites,
flags=flags,
reason=reason,
**kwargs,
**kwargs, # type: ignore
)
if payload is not None:
# the payload will always be the proper channel payload
Expand Down Expand Up @@ -3619,7 +3619,7 @@ async def edit(
default_sort_order=default_sort_order,
default_layout=default_layout,
reason=reason,
**kwargs,
**kwargs, # type: ignore
)
if payload is not None:
# the payload will always be the proper channel payload
Expand Down Expand Up @@ -3994,7 +3994,7 @@ async def create_thread(
stickers=stickers,
)

if auto_archive_duration is not None:
if auto_archive_duration not in (MISSING, None):
auto_archive_duration = cast(
"ThreadArchiveDurationLiteral", try_enum_to_int(auto_archive_duration)
)
Expand Down
32 changes: 23 additions & 9 deletions disnake/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Optional,
Sequence,
Tuple,
TypedDict,
TypeVar,
Union,
overload,
Expand Down Expand Up @@ -78,7 +79,7 @@
from .widget import Widget

if TYPE_CHECKING:
from typing_extensions import Never
from typing_extensions import Never, NotRequired

from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime
from .app_commands import APIApplicationCommand
Expand Down Expand Up @@ -173,6 +174,17 @@ class GatewayParams(NamedTuple):
zlib: bool = True


# used for typing the ws parameter dict in the connect() loop
class _WebSocketParams(TypedDict):
initial: bool
shard_id: Optional[int]
gateway: Optional[str]

sequence: NotRequired[Optional[int]]
resume: NotRequired[bool]
session: NotRequired[Optional[str]]


class Client:
"""Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API.
Expand Down Expand Up @@ -1082,7 +1094,7 @@ async def connect(
if not ignore_session_start_limit and self.session_start_limit.remaining == 0:
raise SessionStartLimitReached(self.session_start_limit)

ws_params = {
ws_params: _WebSocketParams = {
"initial": True,
"shard_id": self.shard_id,
"gateway": initial_gateway,
Expand All @@ -1106,6 +1118,7 @@ async def connect(

while True:
await self.ws.poll_event()

except ReconnectWebSocket as e:
_log.info("Got a request to %s the websocket.", e.op)
self.dispatch("disconnect")
Expand All @@ -1118,6 +1131,7 @@ async def connect(
gateway=self.ws.resume_gateway if e.resume else initial_gateway,
)
continue

except (
OSError,
HTTPException,
Expand Down Expand Up @@ -1198,7 +1212,8 @@ async def close(self) -> None:
# if an error happens during disconnects, disregard it.
pass

if self.ws is not None and self.ws.open:
# can be None if not connected
if self.ws is not None and self.ws.open: # pyright: ignore[reportUnnecessaryComparison]
await self.ws.close(code=1000)

await self.http.close()
Expand Down Expand Up @@ -1849,16 +1864,15 @@ async def change_presence(

await self.ws.change_presence(activity=activity, status=status_str)

activities = () if activity is None else (activity,)
for guild in self._connection.guilds:
me = guild.me
if me is None:
if me is None: # pyright: ignore[reportUnnecessaryComparison]
# may happen if guild is unavailable
continue

if activity is not None:
me.activities = (activity,) # type: ignore
else:
me.activities = ()

# Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...]
me.activities = activities # type: ignore
me.status = status

# Guild stuff
Expand Down
48 changes: 32 additions & 16 deletions disnake/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Dict,
Generic,
List,
Literal,
Optional,
Tuple,
Type,
Expand All @@ -22,11 +23,12 @@
from .utils import MISSING, assert_never, get_slots

if TYPE_CHECKING:
from typing_extensions import Self
from typing_extensions import Self, TypeAlias

from .emoji import Emoji
from .types.components import (
ActionRow as ActionRowPayload,
AnySelectMenu as AnySelectMenuPayload,
BaseSelectMenu as BaseSelectMenuPayload,
ButtonComponent as ButtonComponentPayload,
ChannelSelectMenu as ChannelSelectMenuPayload,
Expand Down Expand Up @@ -63,12 +65,16 @@
"MentionableSelectMenu",
"ChannelSelectMenu",
]
MessageComponent = Union["Button", "AnySelectMenu"]

if TYPE_CHECKING: # TODO: remove when we add modal select support
from typing_extensions import TypeAlias
SelectMenuType = Literal[
ComponentType.string_select,
ComponentType.user_select,
ComponentType.role_select,
ComponentType.mentionable_select,
ComponentType.channel_select,
]

# ModalComponent = Union["TextInput", "AnySelectMenu"]
MessageComponent = Union["Button", "AnySelectMenu"]
ModalComponent: TypeAlias = "TextInput"

NestedComponent = Union[MessageComponent, ModalComponent]
Expand Down Expand Up @@ -131,8 +137,6 @@ class ActionRow(Component, Generic[ComponentT]):
Attributes
----------
type: :class:`ComponentType`
The type of component.
children: List[Union[:class:`Button`, :class:`BaseSelectMenu`, :class:`TextInput`]]
The children components that this holds, if any.
"""
Expand All @@ -142,10 +146,9 @@ class ActionRow(Component, Generic[ComponentT]):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__

def __init__(self, data: ActionRowPayload) -> None:
self.type: ComponentType = try_enum(ComponentType, data["type"])
self.children: List[ComponentT] = [
_component_factory(d) for d in data.get("components", [])
]
self.type: Literal[ComponentType.action_row] = ComponentType.action_row
children = [_component_factory(d) for d in data.get("components", [])]
self.children: List[ComponentT] = children # type: ignore

def to_dict(self) -> ActionRowPayload:
return {
Expand Down Expand Up @@ -195,7 +198,7 @@ class Button(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__

def __init__(self, data: ButtonComponentPayload) -> None:
self.type: ComponentType = try_enum(ComponentType, data["type"])
self.type: Literal[ComponentType.button] = ComponentType.button
self.style: ButtonStyle = try_enum(ButtonStyle, data["style"])
self.custom_id: Optional[str] = data.get("custom_id")
self.url: Optional[str] = data.get("url")
Expand All @@ -209,7 +212,7 @@ def __init__(self, data: ButtonComponentPayload) -> None:

def to_dict(self) -> ButtonComponentPayload:
payload: ButtonComponentPayload = {
"type": 2,
"type": self.type.value,
"style": self.style.value,
"disabled": self.disabled,
}
Expand Down Expand Up @@ -273,8 +276,13 @@ class BaseSelectMenu(Component):

__repr_info__: ClassVar[Tuple[str, ...]] = __slots__

def __init__(self, data: BaseSelectMenuPayload) -> None:
self.type: ComponentType = try_enum(ComponentType, data["type"])
# n.b: ideally this would be `BaseSelectMenuPayload`,
# but pyright made TypedDict keys invariant and doesn't
# fully support readonly items yet (which would help avoid this)
def __init__(self, data: AnySelectMenuPayload) -> None:
component_type = try_enum(ComponentType, data["type"])
self.type: SelectMenuType = component_type # type: ignore

self.custom_id: str = data["custom_id"]
self.placeholder: Optional[str] = data.get("placeholder")
self.min_values: int = data.get("min_values", 1)
Expand Down Expand Up @@ -329,6 +337,7 @@ class StringSelectMenu(BaseSelectMenu):
__slots__: Tuple[str, ...] = ("options",)

__repr_info__: ClassVar[Tuple[str, ...]] = BaseSelectMenu.__repr_info__ + __slots__
type: Literal[ComponentType.string_select]

def __init__(self, data: StringSelectMenuPayload) -> None:
super().__init__(data)
Expand Down Expand Up @@ -372,6 +381,8 @@ class UserSelectMenu(BaseSelectMenu):

__slots__: Tuple[str, ...] = ()

type: Literal[ComponentType.user_select]

if TYPE_CHECKING:

def to_dict(self) -> UserSelectMenuPayload:
Expand Down Expand Up @@ -405,6 +416,8 @@ class RoleSelectMenu(BaseSelectMenu):

__slots__: Tuple[str, ...] = ()

type: Literal[ComponentType.role_select]

if TYPE_CHECKING:

def to_dict(self) -> RoleSelectMenuPayload:
Expand Down Expand Up @@ -438,6 +451,8 @@ class MentionableSelectMenu(BaseSelectMenu):

__slots__: Tuple[str, ...] = ()

type: Literal[ComponentType.mentionable_select]

if TYPE_CHECKING:

def to_dict(self) -> MentionableSelectMenuPayload:
Expand Down Expand Up @@ -475,6 +490,7 @@ class ChannelSelectMenu(BaseSelectMenu):
__slots__: Tuple[str, ...] = ("channel_types",)

__repr_info__: ClassVar[Tuple[str, ...]] = BaseSelectMenu.__repr_info__ + __slots__
type: Literal[ComponentType.channel_select]

def __init__(self, data: ChannelSelectMenuPayload) -> None:
super().__init__(data)
Expand Down Expand Up @@ -643,7 +659,7 @@ class TextInput(Component):
def __init__(self, data: TextInputPayload) -> None:
style = data.get("style", TextInputStyle.short.value)

self.type: ComponentType = try_enum(ComponentType, data["type"])
self.type: Literal[ComponentType.text_input] = ComponentType.text_input
self.custom_id: str = data["custom_id"]
self.style: TextInputStyle = try_enum(TextInputStyle, style)
self.label: Optional[str] = data.get("label")
Expand Down
2 changes: 1 addition & 1 deletion disnake/emoji.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def roles(self) -> List[Role]:
and count towards a separate limit of 25 emojis.
"""
guild = self.guild
if guild is None:
if guild is None: # pyright: ignore[reportUnnecessaryComparison]
return []

return [role for role in guild.roles if self._roles.has(role.id)]
Expand Down
14 changes: 12 additions & 2 deletions disnake/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"ActivityType",
"NotificationLevel",
"TeamMembershipState",
"TeamMemberRole",
"WebhookType",
"ExpireBehaviour",
"ExpireBehavior",
Expand Down Expand Up @@ -466,7 +467,7 @@ def category(self) -> Optional[AuditLogActionCategory]:
@property
def target_type(self) -> Optional[str]:
v = self.value
if v == -1:
if v == -1: # pyright: ignore[reportUnnecessaryComparison]
return "all"
elif v < 10:
return "guild"
Expand Down Expand Up @@ -551,6 +552,15 @@ class TeamMembershipState(Enum):
accepted = 2


class TeamMemberRole(Enum):
admin = "admin"
developer = "developer"
read_only = "read_only"

def __str__(self) -> str:
return self.name


class WebhookType(Enum):
incoming = 1
channel_follower = 2
Expand Down Expand Up @@ -627,7 +637,7 @@ class ComponentType(Enum):
action_row = 1
button = 2
string_select = 3
select = string_select # backwards compatibility
select = 3 # backwards compatibility
text_input = 4
user_select = 5
role_select = 6
Expand Down
2 changes: 1 addition & 1 deletion disnake/ext/commands/base_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def _prepare_cooldowns(self, inter: ApplicationCommandInteraction) -> None:
dt = inter.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
bucket = self._buckets.get_bucket(inter, current) # type: ignore
if bucket is not None:
if bucket is not None: # pyright: ignore[reportUnnecessaryComparison]
retry_after = bucket.update_rate_limit(current)
if retry_after:
raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore
Expand Down
Loading

0 comments on commit f38fed2

Please sign in to comment.