Skip to content

Commit

Permalink
fix: reviews, feat: BucketStorageProtocol
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentRPS committed Jan 3, 2024
1 parent 38feea0 commit d61d0ab
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 16 deletions.
5 changes: 2 additions & 3 deletions discord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
from .utils import MISSING
from .voice_client import VoiceClient
from .webhook import Webhook
from .webhook.async_ import AsyncWebhookAdapter, async_context
from .widget import Widget

if TYPE_CHECKING:
Expand Down Expand Up @@ -199,7 +198,7 @@ class Client:
To enable these events, this must be set to ``True``. Defaults to ``False``.
.. versionadded:: 2.0
bucket_storage_cls: :class:`type`[:class:`.rate_limiting.BucketStorage`]
bucket_storage_cls: :class:`type`[:class:`.rate_limiting.BucketStorageProtocol`]
The class to use for storing rate limit buckets given by Discord.
.. versionadded:: 2.5
Expand Down Expand Up @@ -260,7 +259,7 @@ def __init__(
proxy_auth=proxy_auth,
unsync_clock=unsync_clock,
loop=self.loop,
maximum_rate_limit_wait_time=options.pop("maximum_rate_limit_time", -1),
maximum_rate_limit_time=options.pop("maximum_rate_limit_time", -1),
)

self._handlers: dict[str, Callable] = {"ready": self._handle_ready}
Expand Down
19 changes: 11 additions & 8 deletions discord/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
NotFound,
)
from .gateway import DiscordClientWebSocketResponse
from .rate_limiting import BucketStorage, DynamicBucket
from .rate_limiting import BucketStorageProtocol, DynamicBucket
from .utils import MISSING, warn_deprecated

_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -80,7 +80,6 @@
from .types.snowflake import Snowflake, SnowflakeList

T = TypeVar("T")
BE = TypeVar("BE", bound=BaseException)
Response = Coroutine[Any, Any, T]

API_VERSION: int = 10
Expand Down Expand Up @@ -138,14 +137,14 @@ class HTTPClient:

def __init__(
self,
bucket_storage: BucketStorage,
bucket_storage: BucketStorageProtocol,
connector: aiohttp.BaseConnector | None = None,
*,
proxy: str | None = None,
proxy_auth: aiohttp.BasicAuth | None = None,
loop: asyncio.AbstractEventLoop | None = None,
unsync_clock: bool = True,
maximum_rate_limit_wait_time: int | float = -1,
maximum_rate_limit_time: int | float = -1,
) -> None:
self.loop: asyncio.AbstractEventLoop = (
asyncio.get_event_loop() if loop is None else loop
Expand All @@ -157,7 +156,7 @@ def __init__(
self.proxy: str | None = proxy
self.proxy_auth: aiohttp.BasicAuth | None = proxy_auth
self.use_clock: bool = not unsync_clock
self.maximum_rate_limit_wait_time = maximum_rate_limit_wait_time
self.maximum_rate_limit_time = maximum_rate_limit_time

user_agent = (
"DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}"
Expand Down Expand Up @@ -203,6 +202,10 @@ async def request(
method = route.method
url = route.url

if not self._rate_limit.ready:
await self._rate_limit.start()
self._rate_limit.ready = True

bucket = await self._rate_limit.get_or_create(bucket_id)

# header creation
Expand Down Expand Up @@ -310,12 +313,12 @@ async def request(
is_global: bool = data.get("global", False)

if (
retry_after > self.maximum_rate_limit_wait_time
and self.maximum_rate_limit_wait_time != -1
retry_after > self.maximum_rate_limit_time
and self.maximum_rate_limit_time != -1
):
raise HTTPException(
response,
f"rate limit wait costed over maximum of {self.maximum_rate_limit_wait_time}",
f"Retrying rate limit would take longer than the maximum of {self.maximum_rate_limit_wait_time} seconds given",
)

if is_global:
Expand Down
113 changes: 109 additions & 4 deletions discord/rate_limiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import gc
import time
from contextlib import asynccontextmanager
from typing import AsyncIterator, Literal, cast
from typing import AsyncIterator, Literal, Protocol, cast

from .utils import MISSING

from .errors import DiscordException

Expand All @@ -46,6 +48,9 @@ class GlobalRateLimit:
The concurrency to reset every `per` seconds.
per: :class:`int` | :class:`float`
Number of seconds to wait until resetting `concurrency`.
remaining: :class:`int` | MISSING
Number of available requests remaining. If the value of remaining
is larger than concurrency a `ValueError` will be raised.
Attributes
----------
Expand All @@ -57,7 +62,12 @@ class GlobalRateLimit:
Unix timestamp of when this class will next reset.
"""

def __init__(self, concurrency: int, per: float | int) -> None:
def __init__(
self,
concurrency: int,
per: float | int,
remaining: int = MISSING
) -> None:
self.concurrency: int = concurrency
self.per: float | int = per

Expand All @@ -67,6 +77,9 @@ def __init__(self, concurrency: int, per: float | int) -> None:
self.pending_reset: bool = False
self.reset_at: int | float | None = None

if remaining is not MISSING:
raise ValueError("Given rate limit remaining value is larger than concurrency limit")

async def __aenter__(self) -> GlobalRateLimit:
if not self.loop:
self.loop = asyncio.get_running_loop()
Expand Down Expand Up @@ -293,9 +306,14 @@ def _reset(self) -> None:
self.release(self.limit)


class BucketStorage:
class BucketStorageProtocol(Protocol):
"""A customizable, optionally replacable storage medium for buckets.
Attributes
----------
ready: :class:`bool`
Whether the BucketStorage is ready.
Parameters
----------
concurrency: :class:`int`
Expand All @@ -304,10 +322,97 @@ class BucketStorage:
Number of seconds to wait until resetting `concurrency`.
"""

per: int
concurrency: int
ready: bool
global_concurrency: GlobalRateLimit

def __init__(self, per: int = 1, concurrency: int = 50) -> None:
...

async def start(self) -> None:
"""An internal asynchronous function for BucketStorage
used for initiating certain libraries such as a database.
"""

async def close(self) -> None:
"""An internal asynchronous function for BucketStorage
used for closing connections to databases and other such.
"""

async def append(self, id: str, bucket: Bucket) -> None:
"""Append a permanent bucket.
Parameters
----------
id: :class:`str`
This bucket's identifier.
bucket: :class:`.Bucket`
The bucket to append.
"""

async def get(self, id: str) -> Bucket | None:
"""Get a permanent bucket.
Parameters
----------
id: :class:`str`
This bucket's identifier.
Returns
-------
:class:`.Bucket` or `None`
"""

async def get_or_create(self, id: str) -> Bucket:
"""Get or create a permanent bucket.
Parameters
----------
id: :class:`str`
This bucket's identifier.
Returns
-------
:class:`.Bucket`
"""

async def temp_bucket(self, id: str) -> DynamicBucket | None:
"""Fetch a temporary bucket.
Parameters
----------
id: :class:`str`
This bucket's identifier.
Returns
-------
:class:`.DynamicBucket` or `None`
"""

async def push_temp_bucket(self, id: str, bucket: DynamicBucket) -> None:
"""Push a temporary bucket to storage.
Parameters
----------
id: :class:`str`
This bucket's identifier.
"""

async def pop_temp_bucket(self, id: str) -> None:
"""Pop a temporary bucket which *may* be in storage.
Parameters
----------
id: :class:`str`
This bucket's identifier.
"""

class BucketStorage(BucketStorageProtocol):
def __init__(self, per: int = 1, concurrency: int = 50) -> None:
self._buckets: dict[str, Bucket] = {}
self.ready = True
self.global_concurrency = GlobalRateLimit(concurrency, per)
self.webhook_global_concurrency = GlobalRateLimit(30, 60)

gc.callbacks.append(self._collect_buckets)

Expand Down
1 change: 0 additions & 1 deletion discord/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
from .message import Message
from .object import Object
from .partial_emoji import PartialEmoji
from .rate_limiting import BucketStorage
from .raw_models import *
from .role import Role
from .scheduled_events import ScheduledEvent
Expand Down

0 comments on commit d61d0ab

Please sign in to comment.