From d61d0ab8d30bf119a8df8fe37c67c72534db70d7 Mon Sep 17 00:00:00 2001 From: VincentRPS Date: Wed, 3 Jan 2024 19:58:38 +0800 Subject: [PATCH] fix: reviews, feat: BucketStorageProtocol --- discord/client.py | 5 +- discord/http.py | 19 ++++--- discord/rate_limiting.py | 113 +++++++++++++++++++++++++++++++++++++-- discord/state.py | 1 - 4 files changed, 122 insertions(+), 16 deletions(-) diff --git a/discord/client.py b/discord/client.py index 3ef5173ee4..fae80832b9 100644 --- a/discord/client.py +++ b/discord/client.py @@ -63,7 +63,6 @@ from .utils import MISSING from .voice_client import VoiceClient from .webhook import Webhook -from .webhook.async_ import AsyncWebhookAdapter, async_context from .widget import Widget if TYPE_CHECKING: @@ -199,7 +198,7 @@ class Client: To enable these events, this must be set to ``True``. Defaults to ``False``. .. versionadded:: 2.0 - bucket_storage_cls: :class:`type`[:class:`.rate_limiting.BucketStorage`] + bucket_storage_cls: :class:`type`[:class:`.rate_limiting.BucketStorageProtocol`] The class to use for storing rate limit buckets given by Discord. .. versionadded:: 2.5 @@ -260,7 +259,7 @@ def __init__( proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop, - maximum_rate_limit_wait_time=options.pop("maximum_rate_limit_time", -1), + maximum_rate_limit_time=options.pop("maximum_rate_limit_time", -1), ) self._handlers: dict[str, Callable] = {"ready": self._handle_ready} diff --git a/discord/http.py b/discord/http.py index 042e6d64f4..2ef827523f 100644 --- a/discord/http.py +++ b/discord/http.py @@ -44,7 +44,7 @@ NotFound, ) from .gateway import DiscordClientWebSocketResponse -from .rate_limiting import BucketStorage, DynamicBucket +from .rate_limiting import BucketStorageProtocol, DynamicBucket from .utils import MISSING, warn_deprecated _log = logging.getLogger(__name__) @@ -80,7 +80,6 @@ from .types.snowflake import Snowflake, SnowflakeList T = TypeVar("T") - BE = TypeVar("BE", bound=BaseException) Response = Coroutine[Any, Any, T] API_VERSION: int = 10 @@ -138,14 +137,14 @@ class HTTPClient: def __init__( self, - bucket_storage: BucketStorage, + bucket_storage: BucketStorageProtocol, connector: aiohttp.BaseConnector | None = None, *, proxy: str | None = None, proxy_auth: aiohttp.BasicAuth | None = None, loop: asyncio.AbstractEventLoop | None = None, unsync_clock: bool = True, - maximum_rate_limit_wait_time: int | float = -1, + maximum_rate_limit_time: int | float = -1, ) -> None: self.loop: asyncio.AbstractEventLoop = ( asyncio.get_event_loop() if loop is None else loop @@ -157,7 +156,7 @@ def __init__( self.proxy: str | None = proxy self.proxy_auth: aiohttp.BasicAuth | None = proxy_auth self.use_clock: bool = not unsync_clock - self.maximum_rate_limit_wait_time = maximum_rate_limit_wait_time + self.maximum_rate_limit_time = maximum_rate_limit_time user_agent = ( "DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}" @@ -203,6 +202,10 @@ async def request( method = route.method url = route.url + if not self._rate_limit.ready: + await self._rate_limit.start() + self._rate_limit.ready = True + bucket = await self._rate_limit.get_or_create(bucket_id) # header creation @@ -310,12 +313,12 @@ async def request( is_global: bool = data.get("global", False) if ( - retry_after > self.maximum_rate_limit_wait_time - and self.maximum_rate_limit_wait_time != -1 + retry_after > self.maximum_rate_limit_time + and self.maximum_rate_limit_time != -1 ): raise HTTPException( response, - f"rate limit wait costed over maximum of {self.maximum_rate_limit_wait_time}", + f"Retrying rate limit would take longer than the maximum of {self.maximum_rate_limit_wait_time} seconds given", ) if is_global: diff --git a/discord/rate_limiting.py b/discord/rate_limiting.py index 2123feecba..4963841b0b 100644 --- a/discord/rate_limiting.py +++ b/discord/rate_limiting.py @@ -28,7 +28,9 @@ import gc import time from contextlib import asynccontextmanager -from typing import AsyncIterator, Literal, cast +from typing import AsyncIterator, Literal, Protocol, cast + +from .utils import MISSING from .errors import DiscordException @@ -46,6 +48,9 @@ class GlobalRateLimit: The concurrency to reset every `per` seconds. per: :class:`int` | :class:`float` Number of seconds to wait until resetting `concurrency`. + remaining: :class:`int` | MISSING + Number of available requests remaining. If the value of remaining + is larger than concurrency a `ValueError` will be raised. Attributes ---------- @@ -57,7 +62,12 @@ class GlobalRateLimit: Unix timestamp of when this class will next reset. """ - def __init__(self, concurrency: int, per: float | int) -> None: + def __init__( + self, + concurrency: int, + per: float | int, + remaining: int = MISSING + ) -> None: self.concurrency: int = concurrency self.per: float | int = per @@ -67,6 +77,9 @@ def __init__(self, concurrency: int, per: float | int) -> None: self.pending_reset: bool = False self.reset_at: int | float | None = None + if remaining is not MISSING: + raise ValueError("Given rate limit remaining value is larger than concurrency limit") + async def __aenter__(self) -> GlobalRateLimit: if not self.loop: self.loop = asyncio.get_running_loop() @@ -293,9 +306,14 @@ def _reset(self) -> None: self.release(self.limit) -class BucketStorage: +class BucketStorageProtocol(Protocol): """A customizable, optionally replacable storage medium for buckets. + Attributes + ---------- + ready: :class:`bool` + Whether the BucketStorage is ready. + Parameters ---------- concurrency: :class:`int` @@ -304,10 +322,97 @@ class BucketStorage: Number of seconds to wait until resetting `concurrency`. """ + per: int + concurrency: int + ready: bool + global_concurrency: GlobalRateLimit + + def __init__(self, per: int = 1, concurrency: int = 50) -> None: + ... + + async def start(self) -> None: + """An internal asynchronous function for BucketStorage + used for initiating certain libraries such as a database. + """ + + async def close(self) -> None: + """An internal asynchronous function for BucketStorage + used for closing connections to databases and other such. + """ + + async def append(self, id: str, bucket: Bucket) -> None: + """Append a permanent bucket. + + Parameters + ---------- + id: :class:`str` + This bucket's identifier. + bucket: :class:`.Bucket` + The bucket to append. + """ + + async def get(self, id: str) -> Bucket | None: + """Get a permanent bucket. + + Parameters + ---------- + id: :class:`str` + This bucket's identifier. + + Returns + ------- + :class:`.Bucket` or `None` + """ + + async def get_or_create(self, id: str) -> Bucket: + """Get or create a permanent bucket. + + Parameters + ---------- + id: :class:`str` + This bucket's identifier. + + Returns + ------- + :class:`.Bucket` + """ + + async def temp_bucket(self, id: str) -> DynamicBucket | None: + """Fetch a temporary bucket. + + Parameters + ---------- + id: :class:`str` + This bucket's identifier. + + Returns + ------- + :class:`.DynamicBucket` or `None` + """ + + async def push_temp_bucket(self, id: str, bucket: DynamicBucket) -> None: + """Push a temporary bucket to storage. + + Parameters + ---------- + id: :class:`str` + This bucket's identifier. + """ + + async def pop_temp_bucket(self, id: str) -> None: + """Pop a temporary bucket which *may* be in storage. + + Parameters + ---------- + id: :class:`str` + This bucket's identifier. + """ + +class BucketStorage(BucketStorageProtocol): def __init__(self, per: int = 1, concurrency: int = 50) -> None: self._buckets: dict[str, Bucket] = {} + self.ready = True self.global_concurrency = GlobalRateLimit(concurrency, per) - self.webhook_global_concurrency = GlobalRateLimit(30, 60) gc.callbacks.append(self._collect_buckets) diff --git a/discord/state.py b/discord/state.py index 70136de1eb..d222ba4518 100644 --- a/discord/state.py +++ b/discord/state.py @@ -61,7 +61,6 @@ from .message import Message from .object import Object from .partial_emoji import PartialEmoji -from .rate_limiting import BucketStorage from .raw_models import * from .role import Role from .scheduled_events import ScheduledEvent