Skip to content

Commit

Permalink
better typing
Browse files Browse the repository at this point in the history
  • Loading branch information
squi-ddy committed Apr 7, 2024
1 parent bd1178a commit 11c2059
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 15 deletions.
6 changes: 3 additions & 3 deletions bot/src/cogs/github_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, bot: Bot, cache: Cache, json_cache: JSONCache) -> None:

self.bot = bot
self.cache = cache
self.github_auth_flows: MutableMapping[str, Tuple[int, int]] = json_cache.register_cache(
self.github_auth_flows = json_cache.register_cache(
"github_auth_flows", self.prune_auth_flows
) # state -> timestamp, discord id

Expand Down Expand Up @@ -109,7 +109,7 @@ async def on_gh_auth_response(self, data) -> Union[str, Tuple[str, int]]:
500,
)

del self.github_auth_flows[params["state"]]
self.github_auth_flows.pop(params["state"])

# get their name
github_user = Github(response.json()["access_token"]).get_user()
Expand Down Expand Up @@ -137,7 +137,7 @@ def prune_auth_flows(self, github_auth_flows: MutableMapping[str, Tuple[int, int
# wrap in list to create a copy of items (we modify the dict in the loop)
for key, auth_flow_data in list(current_auth_flows.items()):
if current_time - auth_flow_data[0] >= 86400:
del github_auth_flows[key]
github_auth_flows.pop(key)


__all__ = ["GithubAuth"]
12 changes: 7 additions & 5 deletions bot/src/cogs/json_cache.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
from typing import Any, Callable, MutableMapping, Optional, Tuple
from typing import Any, Callable, MutableMapping, Optional, Tuple, TypeVar

import orjson
from nextcord.ext import tasks
from nextcord.ext.commands import Bot, Cog

logger = logging.getLogger(__name__)
SaveCallback = Callable[[MutableMapping[str, Any]], None]

DataType = TypeVar('DataType')
SaveCallback = Callable[[MutableMapping[str, DataType]], None]


class JSONCache(Cog):
Expand All @@ -19,8 +21,8 @@ def __init__(self, bot: Bot) -> None:
self.json_caches: MutableMapping[str, Tuple[SaveCallback, MutableMapping[str, Any]]] = {}

def register_cache(
self, cache_name: str, do_before_save: Optional[SaveCallback] = None
) -> MutableMapping[str, Any]:
self, cache_name: str, do_before_save: Optional[SaveCallback[DataType]] = None
) -> MutableMapping[str, DataType]:
if not do_before_save:
_do_before_save: SaveCallback = lambda _: None
else:
Expand All @@ -32,7 +34,7 @@ def register_cache(
except FileNotFoundError:
data = b"{}"

cache: MutableMapping[str, Any] = orjson.loads(data)
cache: MutableMapping[str, DataType] = orjson.loads(data)

logger.info(f"Loaded {len(cache)} records in {cache_name}.json")

Expand Down
8 changes: 4 additions & 4 deletions bot/src/cogs/ms_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, bot: Bot, cache: Cache, ui_helper: UIHelper, json_cache: JSON
authority=f"https://login.microsoftonline.com/{config.ms_auth_tenant_id}",
)

self.auth_flows: MutableMapping[str, Tuple[int, int, Any]] = json_cache.register_cache(
self.auth_flows = json_cache.register_cache(
"auth_flows", self.prune_auth_flows
) # state -> timestamp, discord id, flow

Expand Down Expand Up @@ -151,7 +151,7 @@ def get_ms_auth_link(self, member_id: int) -> str:

auth_flow = self.application.initiate_auth_code_flow(
scopes=["User.Read"],
redirect_uri=f"{config.ms_auth_redirect_domain}",
redirect_uri=config.ms_auth_redirect_domain,
state=state,
response_mode="form_post",
)
Expand Down Expand Up @@ -179,7 +179,7 @@ async def on_ms_auth_response(self, data) -> Union[str, Tuple[str, int]]:
500,
)

del self.auth_flows[params["state"]]
self.auth_flows.pop(params["state"])

# get their email and name
user_data = requests.get(
Expand Down Expand Up @@ -324,7 +324,7 @@ def prune_auth_flows(self, auth_flows: MutableMapping[str, Tuple[int, int, Any]]
# wrap in list to create a copy of items (we modify the dict in the loop)
for key, auth_flow_data in list(current_auth_flows.items()):
if current_time - auth_flow_data[0] >= 86400:
del auth_flows[key]
auth_flows.pop(key)


__all__ = ["MSAuth"]
6 changes: 3 additions & 3 deletions bot/src/cogs/ui_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def on_message(self, message: Message) -> None:
self.buttons[str(message.id)] = []

self.buttons[str(message.id)].append((component.custom_id, *self.pending[component.custom_id]))
del self.pending[component.custom_id]
self.pending.pop(component.custom_id)

@Cog.listener()
async def on_raw_message_delete(self, payload: RawMessageDeleteEvent) -> None:
Expand Down Expand Up @@ -151,11 +151,11 @@ def filter_button_id(button: Tuple[str, str, Collection[Any]]) -> bool:
continue

self.buttons[message_id].append((button_id, *self.pending[button_id]))
del self.pending[button_id]
self.pending.pop(button_id)

# if the length is 0, remove the message from the buttons dict
if len(self.buttons[message_id]) == 0:
del self.buttons[message_id]
self.buttons.pop(message_id)

@Cog.listener()
async def on_interaction(self, interaction: Interaction) -> None:
Expand Down

0 comments on commit 11c2059

Please sign in to comment.