Skip to content

Commit

Permalink
Use discordpy's tasks to check RSS
Browse files Browse the repository at this point in the history
  • Loading branch information
raccube committed Aug 16, 2024
1 parent a1d8fb0 commit 73980ef
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 30 deletions.
18 changes: 18 additions & 0 deletions src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ANNOUNCE_CHANNEL_NAME,
WELCOME_CATEGORY_NAME,
PASSWORDS_CHANNEL_NAME,
FEED_CHECK_INTERVAL,
)
from src.commands.join import join
from src.commands.team import (
Expand All @@ -26,6 +27,10 @@
create_team_channel,
)

from discord.ext import tasks

from src.rss import check_posts


class BotClient(discord.Client):
logger: logging.Logger
Expand Down Expand Up @@ -66,6 +71,9 @@ async def setup_hook(self) -> None:
self.tree.copy_global_to(guild=self.guild)
await self.tree.sync(guild=self.guild)

async def setup_hook(self) -> None:
self.check_for_new_blog_posts.start()

async def on_ready(self) -> None:
self.logger.info(f"{self.user} has connected to Discord!")
guild = self.get_guild(self.guild.id)
Expand Down Expand Up @@ -134,6 +142,16 @@ async def on_member_remove(self, member: discord.Member) -> None:
await channel.delete()
self.logger.info(f"Deleted channel '{channel.name}', because it has no users.")

@tasks.loop(seconds=FEED_CHECK_INTERVAL)
async def check_for_new_blog_posts(self):
self.logger.info("Checking for new blog posts")
await check_posts(self.get_guild(int(os.getenv('DISCORD_GUILD_ID'))))

@check_for_new_blog_posts.before_loop
async def before_check_for_new_blog_posts(self):
await self.wait_until_ready()


async def load_passwords(self) -> AsyncGenerator[Tuple[str, str], None]:
"""
Returns a mapping from role name to the password for that role.
Expand Down
13 changes: 1 addition & 12 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import os
import sys
import logging
Expand All @@ -7,7 +6,6 @@
from discord import Intents

from src.bot import BotClient
from rss import post_check_timer

logger = logging.getLogger("srbot")
logger.setLevel(logging.INFO)
Expand All @@ -26,13 +24,4 @@
exit(1)

bot = BotClient(logger=logger, intents=intents)
loop = asyncio.get_event_loop()

try:
loop.create_task(post_check_timer(bot))
loop.run_until_complete(bot.start(token))
except KeyboardInterrupt:
loop.run_until_complete(bot.close())
# cancel all tasks lingering
finally:
loop.close()
bot.run(token)
21 changes: 3 additions & 18 deletions src/rss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import os
from typing import List

Expand All @@ -7,14 +6,7 @@
from bs4 import BeautifulSoup
from feedparser import FeedParserDict

from src.bot import BotClient
from src.constants import FEED_URL, FEED_CHECK_INTERVAL, FEED_CHANNEL_NAME


def get_feed_channel(bot: BotClient) -> discord.TextChannel:
for channel in bot.get_all_channels():
if channel.name == FEED_CHANNEL_NAME:
return channel
from src.constants import FEED_URL, FEED_CHANNEL_NAME


def get_seen_posts() -> List[str]:
Expand All @@ -30,9 +22,9 @@ def add_seen_post(post_id: str) -> None:
f.write(post_id + '\n')


async def check_posts(bot: BotClient):
async def check_posts(guild: discord.Guild) -> None:
feed = feedparser.parse(FEED_URL)
channel = get_feed_channel(bot)
channel = discord.utils.get(guild.channels, name=FEED_CHANNEL_NAME)
post = feed.entries[0]

if post.id + "\n" not in get_seen_posts():
Expand All @@ -54,10 +46,3 @@ def create_embed(post: FeedParserDict) -> discord.Embed:
embed.set_image(url=post.media_thumbnail[0]['url'])

return embed


async def post_check_timer(bot: BotClient):
await bot.wait_until_ready()
while True:
await check_posts(bot)
await asyncio.sleep(FEED_CHECK_INTERVAL)

0 comments on commit 73980ef

Please sign in to comment.