diff --git a/bot.py b/bot.py index 45530a908f..3701d826ea 100644 --- a/bot.py +++ b/bot.py @@ -1222,25 +1222,36 @@ async def handle_reaction_events(self, payload): return channel = self.get_channel(payload.channel_id) - if not channel: # dm channel not in internal cache - _thread = await self.threads.find(recipient=user) - if not _thread: + thread = None + # dm channel not in internal cache + if not channel: + thread = await self.threads.find(recipient=user) + if not thread: + return + channel = await thread.recipient.create_dm() + if channel.id != payload.channel_id: + return + + from_dm = isinstance(channel, discord.DMChannel) + from_txt = isinstance(channel, discord.TextChannel) + if not from_dm and not from_txt: + return + + if not thread: + params = {"recipient": user} if from_dm else {"channel": channel} + thread = await self.threads.find(**params) + if not thread: return - channel = await _thread.recipient.create_dm() + # thread must exist before doing this API call try: message = await channel.fetch_message(payload.message_id) except (discord.NotFound, discord.Forbidden): return reaction = payload.emoji - close_emoji = await self.convert_emoji(self.config["close_emoji"]) - - if isinstance(channel, discord.DMChannel): - thread = await self.threads.find(recipient=user) - if not thread: - return + if from_dm: if ( payload.event_type == "REACTION_ADD" and message.embeds @@ -1248,7 +1259,7 @@ async def handle_reaction_events(self, payload): and self.config.get("recipient_thread_close") ): ts = message.embeds[0].timestamp - if thread and ts == thread.channel.created_at: + if ts == thread.channel.created_at: # the reacted message is the corresponding thread creation embed # closing thread return await thread.close(closer=user) @@ -1268,11 +1279,10 @@ async def handle_reaction_events(self, payload): logger.warning("Failed to find linked message for reactions: %s", e) return else: - thread = await self.threads.find(channel=channel) - if not thread: - return try: - _, *linked_messages = await thread.find_linked_messages(message.id, either_direction=True) + _, *linked_messages = await thread.find_linked_messages( + message1=message, either_direction=True + ) except ValueError as e: logger.warning("Failed to find linked message for reactions: %s", e) return