From 9a02dbfa2ef881647a8184808cb51d3624f45498 Mon Sep 17 00:00:00 2001 From: ZRunner Date: Tue, 17 Oct 2023 18:23:12 -0400 Subject: [PATCH] feat(gaw): add /giveaway edit command --- src/firebase/caching.py | 5 +++ src/firebase/client.py | 10 +++++ src/modules/giveaways/main.py | 82 ++++++++++++++++++++++++++++++++--- src/utils/custom_args.py | 25 ++++++++++- 4 files changed, 116 insertions(+), 6 deletions(-) diff --git a/src/firebase/caching.py b/src/firebase/caching.py index 3a1727a..f0cf82d 100644 --- a/src/firebase/caching.py +++ b/src/firebase/caching.py @@ -77,6 +77,11 @@ def set_existing_giveaway(self, giveaway: GiveawayData): "Set an existing giveaway" self.giveaways_cache[giveaway["id"]] = giveaway + def edit_giveaway(self, giveaway_id: str, partial_giveaway: GiveawayData): + "Edit a giveaway" + if giveaway_id in self.giveaways_cache: + self.giveaways_cache[giveaway_id].update(partial_giveaway) + def close_giveaway(self, giveaway_id: str, winners: list[int]): "Close a giveaway" if giveaway_id in self.giveaways_cache: diff --git a/src/firebase/client.py b/src/firebase/client.py index 9f6b127..86fe64f 100644 --- a/src/firebase/client.py +++ b/src/firebase/client.py @@ -115,6 +115,16 @@ async def delete_giveaway(self, giveaway_id: str): # update cache self.cache.delete_giveaway(giveaway_id) + async def edit_giveaway(self, giveaway_id: str, data: GiveawayData): + "Edit a giveaway document" + self.log.info("Editing giveaway %s", giveaway_id) + ref = db.reference(f"giveaways/{giveaway_id}") + ref.update({ + **data, + "ends_at": data["ends_at"].isoformat() + }) + self.cache.edit_giveaway(giveaway_id, data) + async def get_giveaways_participants(self, giveaway_id: str) -> Optional[list[int]]: "Get a list of participants for a giveaway" if self.cache.are_participants_sync(giveaway_id): diff --git a/src/modules/giveaways/main.py b/src/modules/giveaways/main.py index ef3084f..a28e6d6 100644 --- a/src/modules/giveaways/main.py +++ b/src/modules/giveaways/main.py @@ -1,19 +1,19 @@ import logging import random -from datetime import timedelta +from datetime import datetime, timedelta, timezone from typing import Optional, Union from uuid import uuid4 import discord from apscheduler.schedulers.asyncio import AsyncIOScheduler -from discord.app_commands import Choice +from discord.app_commands import Choice, Range from discord.ext import commands, tasks from src.cobot import CObot, COInteraction -from src.utils.confirm_view import ConfirmView -from src.utils.custom_args import ColorOption, DurationOption from src.modules.giveaways.types import GiveawayData, GiveawayToSendData from src.modules.giveaways.views import GiveawayView +from src.utils.confirm_view import ConfirmView +from src.utils.custom_args import ColorOption, DateOption, DurationOption AcceptableChannel = (discord.TextChannel, discord.Thread, discord.StageChannel, discord.VoiceChannel) AcceptableChannelType = Union[discord.TextChannel, discord.Thread, discord.StageChannel, discord.VoiceChannel] @@ -126,7 +126,7 @@ async def gw_list(self, interaction: COInteraction, *, include_stopped: bool=Fal await interaction.followup.send(embed=embed) @group.command(name="create") - async def gw_create(self, interaction: COInteraction, *, name: str, description: str, + async def gw_create(self, interaction: COInteraction, *, name: Range[str, 2, 30], description: Range[str, 2, 256], duration: DurationOption, channel: Optional[AcceptableChannelType]=None, color: Optional[ColorOption]=None, max_entries: Optional[int]=None, winners_count: int=1): @@ -211,6 +211,59 @@ async def gw_delete_autocomplete(self, interaction: COInteraction, current: str) choices.append((priority, gaw["name"], choice)) return [choice for _, _, choice in sorted(choices, key=lambda x: x[0:2])] + @group.command(name="edit") + async def gw_edit(self, interaction: COInteraction, giveaway: str, *, + name: Optional[ Range[str, 2, 30]]=None, description: Optional[ Range[str, 2, 256]]=None, + utc_end_date: Optional[DateOption]=None, color: Optional[ColorOption]=None, + max_entries: Optional[int]=None, winners_count: Optional[int]=None): + "Edit an existing giveaway" + if interaction.guild is None: + return + if all(arg is None for arg in (name, description, utc_end_date, color, max_entries, winners_count)): + await interaction.response.send_message("You must provide at least one argument to edit!") + return + if utc_end_date is not None and utc_end_date < discord.utils.utcnow(): + await interaction.response.send_message("The end date must be in the future!") + return + await interaction.response.defer() + gaw = await self.bot.fb.get_giveaway(giveaway) + if gaw is None: + await interaction.followup.send("Giveaway not found!") + return + # run basic tests + if gaw["guild"] != interaction.guild.id: + await interaction.followup.send("You can only delete giveaways in your own server!") + return + if gaw["ended"]: + await interaction.followup.send("You can't edit an ended giveaway!") + return + # edit original data + gaw = await self._merge_giveaways_data(gaw, name, description, utc_end_date, color, max_entries, winners_count) + # edit embed + message = await self.fetch_gaw_message(gaw) + if message is None: + await interaction.followup.send("Giveaway message not found!") + return + embed = await self.create_active_gaw_embed(gaw) + await message.edit(embed=embed) + # edit database + await self.bot.fb.edit_giveaway(giveaway, gaw) + await interaction.followup.send("Giveaway edited!") + + @gw_edit.autocomplete("giveaway") + async def gw_edit_autocomplete(self, interaction: COInteraction, current: str): + "Autocomplete for the giveaway argument of the edit command" + if interaction.guild_id is None: + return [] + current = current.lower() + choices: list[tuple[bool, str, Choice[str]]] = [] + async for gaw in self.bot.fb.get_giveaways(): + if gaw["guild"] == interaction.guild_id and current in gaw["name"].lower(): + priority = not gaw["name"].lower().startswith(current) + choice = Choice(name=gaw["name"], value=gaw["id"]) + choices.append((priority, gaw["name"], choice)) + return [choice for _, _, choice in sorted(choices, key=lambda x: x[0:2])] + async def create_active_gaw_embed(self, data: GiveawayToSendData, participants_count: int=0): "Create a Discord embed for an active giveaway" embed = discord.Embed( @@ -327,6 +380,25 @@ async def pick_giveaway_winners(self, data: GiveawayData) -> list[int]: winners_count = min(data["winners_count"], len(participants)) return random.sample(participants, winners_count) + async def _merge_giveaways_data(self, original_data: GiveawayData, + name: Optional[str], description: Optional[str], + utc_end_date: Optional[datetime], color: Optional[discord.Colour], + max_entries: Optional[int], winners_count: Optional[int]) -> GiveawayData: + "Update a given giveaway data with new values" + if name is not None: + original_data["name"] = name + if description is not None: + original_data["description"] = description + if utc_end_date is not None: + original_data["ends_at"] = utc_end_date.astimezone(timezone.utc) + if color is not None: + original_data["color"] = color.value + if max_entries is not None: + original_data["max_entries"] = max_entries + if winners_count is not None: + original_data["winners_count"] = winners_count + return original_data + async def setup(bot: CObot): diff --git a/src/utils/custom_args.py b/src/utils/custom_args.py index b2cc8ab..c1c0c81 100644 --- a/src/utils/custom_args.py +++ b/src/utils/custom_args.py @@ -68,4 +68,27 @@ async def transform(self, interaction: discord.Interaction, value: str) -> int: raise ValueError("Invalid duration") return round(duration) -DurationOption = app_commands.Transform[int, DurationTransformer] \ No newline at end of file +DurationOption = app_commands.Transform[int, DurationTransformer] + +# pylint: disable=abstract-method +class DateTransformer(app_commands.Transformer): + """Transform a string into a UTC datetime.datetime""" + + # pylint: disable=arguments-differ + async def transform(self, interaction: discord.Interaction, value: str) -> datetime: + "Converts a string to a datetime.datetime." + try: + date = datetime.fromisoformat(value) + except ValueError: + try: + date = datetime.strptime(value, "%Y-%m-%d %H:%M") + except ValueError: + try: + date = datetime.strptime(value, "%d/%m/%Y %H:%M") + except ValueError: + raise ValueError("Invalid date") from None + if not date.tzinfo: + date = date.replace(tzinfo=timezone.utc) + return date + +DateOption = app_commands.Transform[datetime, DateTransformer]