From 11c2059375d21808ef97dfa4cfa107831716a9c8 Mon Sep 17 00:00:00 2001 From: squiddy Date: Sun, 7 Apr 2024 22:17:38 +0800 Subject: [PATCH] better typing --- bot/src/cogs/github_auth.py | 6 +++--- bot/src/cogs/json_cache.py | 12 +++++++----- bot/src/cogs/ms_auth.py | 8 ++++---- bot/src/cogs/ui_helper.py | 6 +++--- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/bot/src/cogs/github_auth.py b/bot/src/cogs/github_auth.py index 1e09e31..06dabe1 100644 --- a/bot/src/cogs/github_auth.py +++ b/bot/src/cogs/github_auth.py @@ -29,7 +29,7 @@ def __init__(self, bot: Bot, cache: Cache, json_cache: JSONCache) -> None: self.bot = bot self.cache = cache - self.github_auth_flows: MutableMapping[str, Tuple[int, int]] = json_cache.register_cache( + self.github_auth_flows = json_cache.register_cache( "github_auth_flows", self.prune_auth_flows ) # state -> timestamp, discord id @@ -109,7 +109,7 @@ async def on_gh_auth_response(self, data) -> Union[str, Tuple[str, int]]: 500, ) - del self.github_auth_flows[params["state"]] + self.github_auth_flows.pop(params["state"]) # get their name github_user = Github(response.json()["access_token"]).get_user() @@ -137,7 +137,7 @@ def prune_auth_flows(self, github_auth_flows: MutableMapping[str, Tuple[int, int # wrap in list to create a copy of items (we modify the dict in the loop) for key, auth_flow_data in list(current_auth_flows.items()): if current_time - auth_flow_data[0] >= 86400: - del github_auth_flows[key] + github_auth_flows.pop(key) __all__ = ["GithubAuth"] diff --git a/bot/src/cogs/json_cache.py b/bot/src/cogs/json_cache.py index cf0ccd7..96f6684 100644 --- a/bot/src/cogs/json_cache.py +++ b/bot/src/cogs/json_cache.py @@ -1,12 +1,14 @@ import logging -from typing import Any, Callable, MutableMapping, Optional, Tuple +from typing import Any, Callable, MutableMapping, Optional, Tuple, TypeVar import orjson from nextcord.ext import tasks from nextcord.ext.commands import Bot, Cog logger = logging.getLogger(__name__) -SaveCallback = Callable[[MutableMapping[str, Any]], None] + +DataType = TypeVar('DataType') +SaveCallback = Callable[[MutableMapping[str, DataType]], None] class JSONCache(Cog): @@ -19,8 +21,8 @@ def __init__(self, bot: Bot) -> None: self.json_caches: MutableMapping[str, Tuple[SaveCallback, MutableMapping[str, Any]]] = {} def register_cache( - self, cache_name: str, do_before_save: Optional[SaveCallback] = None - ) -> MutableMapping[str, Any]: + self, cache_name: str, do_before_save: Optional[SaveCallback[DataType]] = None + ) -> MutableMapping[str, DataType]: if not do_before_save: _do_before_save: SaveCallback = lambda _: None else: @@ -32,7 +34,7 @@ def register_cache( except FileNotFoundError: data = b"{}" - cache: MutableMapping[str, Any] = orjson.loads(data) + cache: MutableMapping[str, DataType] = orjson.loads(data) logger.info(f"Loaded {len(cache)} records in {cache_name}.json") diff --git a/bot/src/cogs/ms_auth.py b/bot/src/cogs/ms_auth.py index 1cd4e69..ced70a6 100644 --- a/bot/src/cogs/ms_auth.py +++ b/bot/src/cogs/ms_auth.py @@ -38,7 +38,7 @@ def __init__(self, bot: Bot, cache: Cache, ui_helper: UIHelper, json_cache: JSON authority=f"https://login.microsoftonline.com/{config.ms_auth_tenant_id}", ) - self.auth_flows: MutableMapping[str, Tuple[int, int, Any]] = json_cache.register_cache( + self.auth_flows = json_cache.register_cache( "auth_flows", self.prune_auth_flows ) # state -> timestamp, discord id, flow @@ -151,7 +151,7 @@ def get_ms_auth_link(self, member_id: int) -> str: auth_flow = self.application.initiate_auth_code_flow( scopes=["User.Read"], - redirect_uri=f"{config.ms_auth_redirect_domain}", + redirect_uri=config.ms_auth_redirect_domain, state=state, response_mode="form_post", ) @@ -179,7 +179,7 @@ async def on_ms_auth_response(self, data) -> Union[str, Tuple[str, int]]: 500, ) - del self.auth_flows[params["state"]] + self.auth_flows.pop(params["state"]) # get their email and name user_data = requests.get( @@ -324,7 +324,7 @@ def prune_auth_flows(self, auth_flows: MutableMapping[str, Tuple[int, int, Any]] # wrap in list to create a copy of items (we modify the dict in the loop) for key, auth_flow_data in list(current_auth_flows.items()): if current_time - auth_flow_data[0] >= 86400: - del auth_flows[key] + auth_flows.pop(key) __all__ = ["MSAuth"] diff --git a/bot/src/cogs/ui_helper.py b/bot/src/cogs/ui_helper.py index 280e9a3..68e7111 100644 --- a/bot/src/cogs/ui_helper.py +++ b/bot/src/cogs/ui_helper.py @@ -87,7 +87,7 @@ async def on_message(self, message: Message) -> None: self.buttons[str(message.id)] = [] self.buttons[str(message.id)].append((component.custom_id, *self.pending[component.custom_id])) - del self.pending[component.custom_id] + self.pending.pop(component.custom_id) @Cog.listener() async def on_raw_message_delete(self, payload: RawMessageDeleteEvent) -> None: @@ -151,11 +151,11 @@ def filter_button_id(button: Tuple[str, str, Collection[Any]]) -> bool: continue self.buttons[message_id].append((button_id, *self.pending[button_id])) - del self.pending[button_id] + self.pending.pop(button_id) # if the length is 0, remove the message from the buttons dict if len(self.buttons[message_id]) == 0: - del self.buttons[message_id] + self.buttons.pop(message_id) @Cog.listener() async def on_interaction(self, interaction: Interaction) -> None: